Check tensor to be scalar before accessing it via `scalar<T>`

Prevents `CHECK`-fail if access is wrong.

PiperOrigin-RevId: 414159318
Change-Id: I4d610d24a8236959960cc176839888051c657ff3
diff --git a/tensorflow/core/kernels/boosted_trees/resource_ops.cc b/tensorflow/core/kernels/boosted_trees/resource_ops.cc
index 72a8fc4..ab03924 100644
--- a/tensorflow/core/kernels/boosted_trees/resource_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/resource_ops.cc
@@ -56,6 +56,12 @@
     const Tensor* tree_ensemble_serialized_t;
     OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
                                            &tree_ensemble_serialized_t));
+    OP_REQUIRES(
+        context,
+        TensorShapeUtils::IsScalar(tree_ensemble_serialized_t->shape()),
+        errors::InvalidArgument(
+            "tree_ensemble_serialized must be a scalar, got a tensor of shape ",
+            tree_ensemble_serialized_t->shape().DebugString()));
     std::unique_ptr<BoostedTreesEnsembleResource> result(
         new BoostedTreesEnsembleResource());
     if (!result->InitFromSerialized(