Bug fixes for tf.sparse.cross (avoid crashes on bad inputs)
* Updated ValidateInput & CreateOutputTensors to return a Status, rather than calling OP_REQUIRES_OK. (Calling OP_REQUIRES_OK in a helper function will register the error, but won't stop execution -- this means that execution continues until we reach a fatal error, causing TF to have a hard crash).
* Updated the order of some tests to avoid crashes.
PiperOrigin-RevId: 296943120
Change-Id: Ia4e471eeae6a9b1d5e8dd9929ebf59bf20992479
diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc
index a16e34c..c7c538a 100644
--- a/tensorflow/core/kernels/sparse_cross_op.cc
+++ b/tensorflow/core/kernels/sparse_cross_op.cc
@@ -308,8 +308,8 @@
OP_REQUIRES_OK(context,
context->input_list("dense_inputs", &dense_list_in));
- ValidateInput(context, indices_list_in, values_list_in, shapes_list_in,
- dense_list_in);
+ OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
+ shapes_list_in, dense_list_in));
std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
GenerateColumnsFromInput(indices_list_in, values_list_in,
@@ -322,8 +322,10 @@
Tensor* shape_out;
const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
std::vector<int64> output_start_indices(batch_size);
- CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out,
- &shape_out, &output_start_indices);
+ OP_REQUIRES_OK(
+ context,
+ CreateOutputTensors(columns, batch_size, context, &indices_out,
+ &values_out, &shape_out, &output_start_indices));
typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
output_start_indices, indices_out, values_out);
@@ -348,83 +350,93 @@
private:
// Validates input tensors.
- void ValidateInput(OpKernelContext* context,
- const OpInputList& indices_list_in,
- const OpInputList& values_list_in,
- const OpInputList& shapes_list_in,
- const OpInputList& dense_list_in) {
+ Status ValidateInput(const OpInputList& indices_list_in,
+ const OpInputList& values_list_in,
+ const OpInputList& shapes_list_in,
+ const OpInputList& dense_list_in) {
const auto size = indices_list_in.size();
// Validates indices_list_in OpInputList.
for (int i = 0; i < size; i++) {
- OP_REQUIRES(
- context, TensorShapeUtils::IsMatrix(indices_list_in[i].shape()),
- errors::InvalidArgument(
- "Input indices should be a matrix but received shape ",
- indices_list_in[i].shape().DebugString(), " at position ", i));
- OP_REQUIRES(
- context, indices_list_in[i].shape().dim_size(1) == 2,
- errors::InvalidArgument("Expected D2 of index to be 2 got ",
- indices_list_in[i].shape().dim_size(1),
- " at position ", i));
+ if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
+ return errors::InvalidArgument(
+ "Input indices should be a matrix but received shape ",
+ indices_list_in[i].shape().DebugString(), " at position ", i);
+ }
+ if (indices_list_in[i].shape().dim_size(1) != 2) {
+ return errors::InvalidArgument("Expected D2 of index to be 2 got ",
+ indices_list_in[i].shape().dim_size(1),
+ " at position ", i);
+ }
}
// Validates values_list_in OpInputList.
- OP_REQUIRES(
- context, values_list_in.size() == size,
- errors::InvalidArgument("Expected ", size, " input values, got ",
- values_list_in.size()));
+ if (values_list_in.size() != size) {
+ return errors::InvalidArgument("Expected ", size, " input values, got ",
+ values_list_in.size());
+ }
for (int i = 0; i < size; i++) {
- OP_REQUIRES(
- context, TensorShapeUtils::IsVector(values_list_in[i].shape()),
- errors::InvalidArgument(
- "Input values should be a std::vector but received shape ",
- values_list_in[i].shape().DebugString(), " at position ", i));
- OP_REQUIRES(
- context,
- indices_list_in[i].shape().dim_size(0) ==
- values_list_in[i].shape().dim_size(0),
- errors::InvalidArgument(
- "Expected size of values to be ",
- indices_list_in[i].shape().dim_size(0), " got ",
- values_list_in[i].shape().dim_size(0), " at position ", i));
+ if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
+ return errors::InvalidArgument(
+ "Input values should be a vector but received shape ",
+ values_list_in[i].shape().DebugString(), " at position ", i);
+ }
+ if (indices_list_in[i].shape().dim_size(0) !=
+ values_list_in[i].shape().dim_size(0)) {
+ return errors::InvalidArgument(
+ "Expected size of values to be ",
+ indices_list_in[i].shape().dim_size(0), " got ",
+ values_list_in[i].shape().dim_size(0), " at position ", i);
+ }
}
// Validates shapes_list_in OpInputList
- OP_REQUIRES(
- context, shapes_list_in.size() == size,
- errors::InvalidArgument("Expected ", size, " input shapes, got ",
- shapes_list_in.size()));
- const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
+ if (shapes_list_in.size() != size) {
+ return errors::InvalidArgument("Expected ", size, " input shapes, got ",
+ shapes_list_in.size());
+ }
for (int i = 0; i < size; i++) {
- OP_REQUIRES(
- context, TensorShapeUtils::IsVector(shapes_list_in[i].shape()),
- errors::InvalidArgument(
- "Input shapes should be a std::vector but received shape ",
- shapes_list_in[i].shape().DebugString(), " at position ", i));
+ if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
+ return errors::InvalidArgument(
+ "Input shapes should be a vector but received shape ",
+ shapes_list_in[i].shape().DebugString(), " at position ", i);
+ }
- OP_REQUIRES(
- context, shapes_list_in[i].vec<int64>().size() == 2,
- errors::InvalidArgument("shape should imply a 2D tensor, but got ",
- shapes_list_in[i].shape().DebugString(),
- " at position ", i));
- OP_REQUIRES(context, shapes_list_in[i].vec<int64>()(0) == batch_size,
- errors::InvalidArgument(
- "Expected batch size ", batch_size, " got ",
- shapes_list_in[i].vec<int64>()(0), " at position ", i));
+ if (shapes_list_in[i].vec<int64>().size() != 2) {
+ return errors::InvalidArgument(
+ "shape should imply a 2D tensor, but got ",
+ shapes_list_in[i].shape().DebugString(), " at position ", i);
+ }
}
// Validates dense_list_in OpInputList
for (int i = 0; i < dense_list_in.size(); ++i) {
- OP_REQUIRES(
- context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()),
- errors::InvalidArgument(
- "Dense inputs should be a matrix but received shape ",
- dense_list_in[i].shape().DebugString(), " at position ", i));
- OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size,
- errors::InvalidArgument("Expected batch size ", batch_size,
- " got ", dense_list_in[i].dim_size(0),
- " at dense tensor ", i));
+ if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
+ return errors::InvalidArgument(
+ "Dense inputs should be a matrix but received shape ",
+ dense_list_in[i].shape().DebugString(), " at position ", i);
+ }
}
+
+ // Validates batch sizes. (Note: we do this after validating the input
+ // shapes, because CalculateBatchSize() depends on inputs having valid
+ // shapes).
+ const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
+ for (int i = 0; i < size; i++) {
+ if (shapes_list_in[i].vec<int64>()(0) != batch_size) {
+ return errors::InvalidArgument(
+ "Expected batch size ", batch_size, " got ",
+ shapes_list_in[i].vec<int64>()(0), " at position ", i);
+ }
+ }
+ for (int i = 0; i < dense_list_in.size(); ++i) {
+ if (dense_list_in[i].dim_size(0) != batch_size) {
+ return errors::InvalidArgument("Expected batch size ", batch_size,
+ " got ", dense_list_in[i].dim_size(0),
+ " at dense tensor ", i);
+ }
+ }
+
+ return Status::OK();
}
// Calculate the batch size from either the shapes input or the dense input.
@@ -500,7 +512,7 @@
// the output SparseTensor.
// It also output_start_indices which contains the start indices for each
// input in the output SparseTensor.
- void CreateOutputTensors(
+ Status CreateOutputTensors(
const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
columns,
int64 batch_size, OpKernelContext* context, Tensor** indices_out,
@@ -518,19 +530,19 @@
}
// Allocates tensors.
- OP_REQUIRES_OK(context,
- context->allocate_output(
- 0, TensorShape({cross_count_total, 2}), indices_out));
- OP_REQUIRES_OK(context,
- context->allocate_output(1, TensorShape({cross_count_total}),
- values_out));
- OP_REQUIRES_OK(context,
- context->allocate_output(2, TensorShape({2}), shape_out));
+ TF_RETURN_IF_ERROR(context->allocate_output(
+ 0, TensorShape({cross_count_total, 2}), indices_out));
+ TF_RETURN_IF_ERROR(context->allocate_output(
+ 1, TensorShape({cross_count_total}), values_out));
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(2, TensorShape({2}), shape_out));
// Sets shape.
auto shape_vec = (*shape_out)->vec<int64>();
shape_vec(0) = batch_size;
shape_vec(1) = max_cross_count;
+
+ return Status::OK();
}
// Returns number of crosses for a given batch_index
diff --git a/tensorflow/python/kernel_tests/sparse_cross_op_test.py b/tensorflow/python/kernel_tests/sparse_cross_op_test.py
index 566bbb5..5037f82 100644
--- a/tensorflow/python/kernel_tests/sparse_cross_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_cross_op_test.py
@@ -23,8 +23,10 @@
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
@@ -410,6 +412,52 @@
constant_op.constant(values, value_type, [len(indices)]),
constant_op.constant(shape, dtypes.int64))
+ def test_invalid_sparse_tensors(self):
+ # Test validation of invalid SparseTensors. The SparseTensor constructor
+ # prevents us from creating invalid SparseTensors (eps. in eager mode),
+ # so we create valid SparseTensors and then modify them to be invalid.
+
+ st1 = sparse_tensor.SparseTensor([[0, 0]], [0], [2, 2])
+ st1._indices = array_ops.zeros([], dtypes.int64)
+ with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+ 'Input indices should be a matrix'):
+ self.evaluate(sparse_ops.sparse_cross([st1]))
+
+ st2 = sparse_tensor.SparseTensor([[0, 0]], [0], [2, 2])
+ st2._values = array_ops.zeros([], dtypes.int64)
+ with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+ 'Input values should be a vector'):
+ self.evaluate(sparse_ops.sparse_cross([st2]))
+
+ st3 = sparse_tensor.SparseTensor([[0, 0]], [0], [2, 2])
+ st3._dense_shape = array_ops.zeros([], dtypes.int64)
+ with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+ 'Input shapes should be a vector'):
+ self.evaluate(sparse_ops.sparse_cross([st3]))
+
+ def test_bad_tensor_shapes(self):
+ # All inputs must be 2D.
+ with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+ 'Expected D2 of index to be 2'):
+ st = sparse_tensor.SparseTensor([[0]], [0], [10]) # 1D SparseTensor
+ self.evaluate(sparse_ops.sparse_cross([st]))
+
+ with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+ 'Dense inputs should be a matrix'):
+ dt = array_ops.zeros([0]) # 1D DenseTensor.
+ self.evaluate(sparse_ops.sparse_cross([dt]))
+
+ def test_batch_size_mismatch(self):
+ st1 = sparse_tensor.SparseTensor([[0, 0]], [0], [10, 10]) # batch size 10
+ st2 = sparse_tensor.SparseTensor([[0, 0]], [0], [7, 10]) # batch size 7
+ dt = array_ops.zeros([5, 0]) # batch size 5
+ with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+ 'Expected batch size'):
+ self.evaluate(sparse_ops.sparse_cross([st1, dt]))
+ with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+ 'Expected batch size'):
+ self.evaluate(sparse_ops.sparse_cross([st1, st2]))
+
if __name__ == '__main__':
test.main()