Ensure `tree_complexity` is scalar in `BoostedTreesCalculateBestGainsPerFeature`

PiperOrigin-RevId: 411074185
Change-Id: Ida9f3c0ead019a6931a1157b0cfb161a9efa9e64
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index 503c192..1c85493 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -98,6 +98,10 @@
     const Tensor* tree_complexity_t;
     OP_REQUIRES_OK(context,
                    context->input("tree_complexity", &tree_complexity_t));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_complexity_t->shape()),
+                errors::InvalidArgument(
+                    "tree_complexity must be a scalar, got a tensor of shape ",
+                    tree_complexity_t->shape().DebugString()));
     const auto tree_complexity = tree_complexity_t->scalar<float>()();
     const Tensor* min_node_weight_t;
     OP_REQUIRES_OK(context,