Check for tensors to be vectors in `BoostedTreesSparseAggregateStatsOp`.
Calling `tensor->vec` should only happen after checking that the tensor shape implies a vector. Otherwise, we can get denial of service via `CHECK`-fails
PiperOrigin-RevId: 410960878
Change-Id: I7b26bec796cbaebde4696862eb855160402b4b0d
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index eb1eab4..c6fae8d 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -1611,6 +1611,10 @@
// node_ids.
const Tensor* node_ids_t;
OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(node_ids_t->shape()),
+ errors::InvalidArgument("node_ids must be a vector, received shape ",
+ node_ids_t->shape().DebugString()));
const auto node_ids = node_ids_t->vec<int32>();
// gradients.