Bug fixes for tf.sparse.cross (avoid crashes on bad inputs)

* Updated ValidateInput & CreateOutputTensors to return a Status, rather than calling OP_REQUIRES_OK.  (Calling OP_REQUIRES_OK in a helper function will register the error, but won't stop execution -- this means that execution continues until we reach a fatal error, causing TF to have a hard crash).

* Updated the order of some tests to avoid crashes.

PiperOrigin-RevId: 296943120
Change-Id: Ia4e471eeae6a9b1d5e8dd9929ebf59bf20992479
diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc
index a16e34c..c7c538a 100644
--- a/tensorflow/core/kernels/sparse_cross_op.cc
+++ b/tensorflow/core/kernels/sparse_cross_op.cc
@@ -308,8 +308,8 @@
     OP_REQUIRES_OK(context,
                    context->input_list("dense_inputs", &dense_list_in));
 
-    ValidateInput(context, indices_list_in, values_list_in, shapes_list_in,
-                  dense_list_in);
+    OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
+                                          shapes_list_in, dense_list_in));
 
     std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
         GenerateColumnsFromInput(indices_list_in, values_list_in,
@@ -322,8 +322,10 @@
     Tensor* shape_out;
     const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
     std::vector<int64> output_start_indices(batch_size);
-    CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out,
-                        &shape_out, &output_start_indices);
+    OP_REQUIRES_OK(
+        context,
+        CreateOutputTensors(columns, batch_size, context, &indices_out,
+                            &values_out, &shape_out, &output_start_indices));
 
     typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
         output_start_indices, indices_out, values_out);
@@ -348,83 +350,93 @@
 
  private:
   // Validates input tensors.
-  void ValidateInput(OpKernelContext* context,
-                     const OpInputList& indices_list_in,
-                     const OpInputList& values_list_in,
-                     const OpInputList& shapes_list_in,
-                     const OpInputList& dense_list_in) {
+  Status ValidateInput(const OpInputList& indices_list_in,
+                       const OpInputList& values_list_in,
+                       const OpInputList& shapes_list_in,
+                       const OpInputList& dense_list_in) {
     const auto size = indices_list_in.size();
     // Validates indices_list_in OpInputList.
     for (int i = 0; i < size; i++) {
-      OP_REQUIRES(
-          context, TensorShapeUtils::IsMatrix(indices_list_in[i].shape()),
-          errors::InvalidArgument(
-              "Input indices should be a matrix but received shape ",
-              indices_list_in[i].shape().DebugString(), " at position ", i));
-      OP_REQUIRES(
-          context, indices_list_in[i].shape().dim_size(1) == 2,
-          errors::InvalidArgument("Expected D2 of index to be 2 got ",
-                                  indices_list_in[i].shape().dim_size(1),
-                                  " at position ", i));
+      if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
+        return errors::InvalidArgument(
+            "Input indices should be a matrix but received shape ",
+            indices_list_in[i].shape().DebugString(), " at position ", i);
+      }
+      if (indices_list_in[i].shape().dim_size(1) != 2) {
+        return errors::InvalidArgument("Expected D2 of index to be 2 got ",
+                                       indices_list_in[i].shape().dim_size(1),
+                                       " at position ", i);
+      }
     }
 
     // Validates values_list_in OpInputList.
-    OP_REQUIRES(
-        context, values_list_in.size() == size,
-        errors::InvalidArgument("Expected ", size, " input values, got ",
-                                values_list_in.size()));
+    if (values_list_in.size() != size) {
+      return errors::InvalidArgument("Expected ", size, " input values, got ",
+                                     values_list_in.size());
+    }
     for (int i = 0; i < size; i++) {
-      OP_REQUIRES(
-          context, TensorShapeUtils::IsVector(values_list_in[i].shape()),
-          errors::InvalidArgument(
-              "Input values should be a std::vector but received shape ",
-              values_list_in[i].shape().DebugString(), " at position ", i));
-      OP_REQUIRES(
-          context,
-          indices_list_in[i].shape().dim_size(0) ==
-              values_list_in[i].shape().dim_size(0),
-          errors::InvalidArgument(
-              "Expected size of values to be ",
-              indices_list_in[i].shape().dim_size(0), " got ",
-              values_list_in[i].shape().dim_size(0), " at position ", i));
+      if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
+        return errors::InvalidArgument(
+            "Input values should be a vector but received shape ",
+            values_list_in[i].shape().DebugString(), " at position ", i);
+      }
+      if (indices_list_in[i].shape().dim_size(0) !=
+          values_list_in[i].shape().dim_size(0)) {
+        return errors::InvalidArgument(
+            "Expected size of values to be ",
+            indices_list_in[i].shape().dim_size(0), " got ",
+            values_list_in[i].shape().dim_size(0), " at position ", i);
+      }
     }
 
     // Validates shapes_list_in OpInputList
-    OP_REQUIRES(
-        context, shapes_list_in.size() == size,
-        errors::InvalidArgument("Expected ", size, " input shapes, got ",
-                                shapes_list_in.size()));
-    const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
+    if (shapes_list_in.size() != size) {
+      return errors::InvalidArgument("Expected ", size, " input shapes, got ",
+                                     shapes_list_in.size());
+    }
     for (int i = 0; i < size; i++) {
-      OP_REQUIRES(
-          context, TensorShapeUtils::IsVector(shapes_list_in[i].shape()),
-          errors::InvalidArgument(
-              "Input shapes should be a std::vector but received shape ",
-              shapes_list_in[i].shape().DebugString(), " at position ", i));
+      if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
+        return errors::InvalidArgument(
+            "Input shapes should be a vector but received shape ",
+            shapes_list_in[i].shape().DebugString(), " at position ", i);
+      }
 
-      OP_REQUIRES(
-          context, shapes_list_in[i].vec<int64>().size() == 2,
-          errors::InvalidArgument("shape should imply a 2D tensor, but got ",
-                                  shapes_list_in[i].shape().DebugString(),
-                                  " at position ", i));
-      OP_REQUIRES(context, shapes_list_in[i].vec<int64>()(0) == batch_size,
-                  errors::InvalidArgument(
-                      "Expected batch size ", batch_size, " got ",
-                      shapes_list_in[i].vec<int64>()(0), " at position ", i));
+      if (shapes_list_in[i].vec<int64>().size() != 2) {
+        return errors::InvalidArgument(
+            "shape should imply a 2D tensor, but got ",
+            shapes_list_in[i].shape().DebugString(), " at position ", i);
+      }
     }
 
     // Validates dense_list_in OpInputList
     for (int i = 0; i < dense_list_in.size(); ++i) {
-      OP_REQUIRES(
-          context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()),
-          errors::InvalidArgument(
-              "Dense inputs should be a matrix but received shape ",
-              dense_list_in[i].shape().DebugString(), " at position ", i));
-      OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size,
-                  errors::InvalidArgument("Expected batch size ", batch_size,
-                                          " got ", dense_list_in[i].dim_size(0),
-                                          " at dense tensor ", i));
+      if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
+        return errors::InvalidArgument(
+            "Dense inputs should be a matrix but received shape ",
+            dense_list_in[i].shape().DebugString(), " at position ", i);
+      }
     }
+
+    // Validates batch sizes.  (Note: we do this after validating the input
+    // shapes, because CalculateBatchSize() depends on inputs having valid
+    // shapes).
+    const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
+    for (int i = 0; i < size; i++) {
+      if (shapes_list_in[i].vec<int64>()(0) != batch_size) {
+        return errors::InvalidArgument(
+            "Expected batch size ", batch_size, " got ",
+            shapes_list_in[i].vec<int64>()(0), " at position ", i);
+      }
+    }
+    for (int i = 0; i < dense_list_in.size(); ++i) {
+      if (dense_list_in[i].dim_size(0) != batch_size) {
+        return errors::InvalidArgument("Expected batch size ", batch_size,
+                                       " got ", dense_list_in[i].dim_size(0),
+                                       " at dense tensor ", i);
+      }
+    }
+
+    return Status::OK();
   }
 
   // Calculate the batch size from either the shapes input or the dense input.
@@ -500,7 +512,7 @@
   // the output SparseTensor.
   // It also output_start_indices which contains the start indices for each
   // input in the output SparseTensor.
-  void CreateOutputTensors(
+  Status CreateOutputTensors(
       const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
           columns,
       int64 batch_size, OpKernelContext* context, Tensor** indices_out,
@@ -518,19 +530,19 @@
     }
 
     // Allocates tensors.
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(
-                       0, TensorShape({cross_count_total, 2}), indices_out));
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(1, TensorShape({cross_count_total}),
-                                            values_out));
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(2, TensorShape({2}), shape_out));
+    TF_RETURN_IF_ERROR(context->allocate_output(
+        0, TensorShape({cross_count_total, 2}), indices_out));
+    TF_RETURN_IF_ERROR(context->allocate_output(
+        1, TensorShape({cross_count_total}), values_out));
+    TF_RETURN_IF_ERROR(
+        context->allocate_output(2, TensorShape({2}), shape_out));
 
     // Sets shape.
     auto shape_vec = (*shape_out)->vec<int64>();
     shape_vec(0) = batch_size;
     shape_vec(1) = max_cross_count;
+
+    return Status::OK();
   }
 
   // Returns number of crosses for a given batch_index
diff --git a/tensorflow/python/kernel_tests/sparse_cross_op_test.py b/tensorflow/python/kernel_tests/sparse_cross_op_test.py
index 566bbb5..5037f82 100644
--- a/tensorflow/python/kernel_tests/sparse_cross_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_cross_op_test.py
@@ -23,8 +23,10 @@
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.platform import test
 
@@ -410,6 +412,52 @@
         constant_op.constant(values, value_type, [len(indices)]),
         constant_op.constant(shape, dtypes.int64))
 
+  def test_invalid_sparse_tensors(self):
+    # Test validation of invalid SparseTensors.  The SparseTensor constructor
+    # prevents us from creating invalid SparseTensors (eps. in eager mode),
+    # so we create valid SparseTensors and then modify them to be invalid.
+
+    st1 = sparse_tensor.SparseTensor([[0, 0]], [0], [2, 2])
+    st1._indices = array_ops.zeros([], dtypes.int64)
+    with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+                                 'Input indices should be a matrix'):
+      self.evaluate(sparse_ops.sparse_cross([st1]))
+
+    st2 = sparse_tensor.SparseTensor([[0, 0]], [0], [2, 2])
+    st2._values = array_ops.zeros([], dtypes.int64)
+    with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+                                 'Input values should be a vector'):
+      self.evaluate(sparse_ops.sparse_cross([st2]))
+
+    st3 = sparse_tensor.SparseTensor([[0, 0]], [0], [2, 2])
+    st3._dense_shape = array_ops.zeros([], dtypes.int64)
+    with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+                                 'Input shapes should be a vector'):
+      self.evaluate(sparse_ops.sparse_cross([st3]))
+
+  def test_bad_tensor_shapes(self):
+    # All inputs must be 2D.
+    with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+                                 'Expected D2 of index to be 2'):
+      st = sparse_tensor.SparseTensor([[0]], [0], [10])  # 1D SparseTensor
+      self.evaluate(sparse_ops.sparse_cross([st]))
+
+    with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+                                 'Dense inputs should be a matrix'):
+      dt = array_ops.zeros([0])  # 1D DenseTensor.
+      self.evaluate(sparse_ops.sparse_cross([dt]))
+
+  def test_batch_size_mismatch(self):
+    st1 = sparse_tensor.SparseTensor([[0, 0]], [0], [10, 10])  # batch size 10
+    st2 = sparse_tensor.SparseTensor([[0, 0]], [0], [7, 10])  # batch size 7
+    dt = array_ops.zeros([5, 0])  # batch size 5
+    with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+                                 'Expected batch size'):
+      self.evaluate(sparse_ops.sparse_cross([st1, dt]))
+    with self.assertRaisesRegexp((errors.InvalidArgumentError, ValueError),
+                                 'Expected batch size'):
+      self.evaluate(sparse_ops.sparse_cross([st1, st2]))
+
 
 if __name__ == '__main__':
   test.main()