Fix overflow CHECK issue with `tf.raw_ops.AddManySparseToTensorsMap`.
PiperOrigin-RevId: 369492969
Change-Id: I1d70d6c0c92e3d7a25bc3b3aa2a0c0ac9688bf81
diff --git a/tensorflow/core/kernels/sparse_tensors_map_ops.cc b/tensorflow/core/kernels/sparse_tensors_map_ops.cc
index c2c0e43..5ea5fca 100644
--- a/tensorflow/core/kernels/sparse_tensors_map_ops.cc
+++ b/tensorflow/core/kernels/sparse_tensors_map_ops.cc
@@ -23,14 +23,12 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
-
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/util/overflow.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
namespace tensorflow {
@@ -254,7 +252,22 @@
errors::InvalidArgument(
"Rank of input SparseTensor should be > 1, but saw rank: ", rank));
- TensorShape tensor_input_shape(input_shape->vec<int64>());
+ auto input_shape_vec = input_shape->vec<int64>();
+ int new_num_elements = 1;
+ bool overflow_ocurred = false;
+ for (int i = 0; i < input_shape_vec.size(); i++) {
+ new_num_elements =
+ MultiplyWithoutOverflow(new_num_elements, input_shape_vec(i));
+ if (new_num_elements < 0) {
+ overflow_ocurred = true;
+ }
+ }
+
+ OP_REQUIRES(
+ context, !overflow_ocurred,
+ errors::Internal("Encountered overflow from large input shape."));
+
+ TensorShape tensor_input_shape(input_shape_vec);
gtl::InlinedVector<int64, 8> std_order(rank);
std::iota(std_order.begin(), std_order.end(), 0);
SparseTensor input_st;
@@ -262,8 +275,7 @@
tensor_input_shape, std_order,
&input_st));
- auto input_shape_t = input_shape->vec<int64>();
- const int64 N = input_shape_t(0);
+ const int64 N = input_shape_vec(0);
Tensor sparse_handles(DT_INT64, TensorShape({N}));
auto sparse_handles_t = sparse_handles.vec<int64>();
@@ -274,7 +286,7 @@
// minibatch entries.
TensorShape output_shape;
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
- input_shape_t.data() + 1,
+ input_shape_vec.data() + 1,
input_shape->NumElements() - 1, &output_shape));
// Get groups by minibatch dimension