Check tensor to be scalar before accessing it via `scalar<T>`
Prevents `CHECK`-fail if access is wrong.
PiperOrigin-RevId: 414159318
Change-Id: I4d610d24a8236959960cc176839888051c657ff3
diff --git a/tensorflow/core/kernels/boosted_trees/resource_ops.cc b/tensorflow/core/kernels/boosted_trees/resource_ops.cc
index 72a8fc4..ab03924 100644
--- a/tensorflow/core/kernels/boosted_trees/resource_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/resource_ops.cc
@@ -56,6 +56,12 @@
const Tensor* tree_ensemble_serialized_t;
OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
&tree_ensemble_serialized_t));
+ OP_REQUIRES(
+ context,
+ TensorShapeUtils::IsScalar(tree_ensemble_serialized_t->shape()),
+ errors::InvalidArgument(
+ "tree_ensemble_serialized must be a scalar, got a tensor of shape ",
+ tree_ensemble_serialized_t->shape().DebugString()));
std::unique_ptr<BoostedTreesEnsembleResource> result(
new BoostedTreesEnsembleResource());
if (!result->InitFromSerialized(