Make SparseFillEmptyRows validate that the length of `values` must be equal to the number of index tuples.
PiperOrigin-RevId: 399969549
Change-Id: I3c2f2ca1c1d2cc88bb5951c6958b38c16e9436c8
diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
index e0c7e18..59eb607 100644
--- a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
+++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
@@ -24,11 +24,13 @@
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
namespace tensorflow {
@@ -223,6 +225,12 @@
values_t.shape().DebugString()),
done);
OP_REQUIRES_ASYNC(
+ context, indices_t.dim_size(0) == values_t.dim_size(0),
+ errors::InvalidArgument("The length of `values` (", values_t.dim_size(0),
+ ") must match the first dimension of `indices` (",
+ indices_t.dim_size(0), ")."),
+ done);
+ OP_REQUIRES_ASYNC(
context, TensorShapeUtils::IsScalar(default_value_t.shape()),
errors::InvalidArgument("default_value must be a scalar, saw: ",
default_value_t.shape().DebugString()),