Backward compatible api change:BoostedTreesUpdateEnsembleV2 works on list of feature_ids.

PiperOrigin-RevId: 289687663
Change-Id: I5d12d044ae42fc34f03a3eaa357bf71b7cb06eec
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsembleV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsembleV2.pbtxt
index 26f1f20..66404dc 100644
--- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsembleV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesUpdateEnsembleV2.pbtxt
@@ -93,6 +93,14 @@
 scalar, dimension of the logits
 END
   }
+  attr {
+    name: "num_groups"
+    description: <<END
+Number of groups of split information to process, where a group contains feature
+ids that are processed together in BoostedTreesCalculateBestFeatureSplitOpV2.
+INFERRED.
+END
+  }
   summary: "Updates the tree ensemble by adding a layer to the last tree being grown"
   description: <<END
 or by starting a new tree.
diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc
index 7816c2c..95fb179 100644
--- a/tensorflow/core/kernels/boosted_trees/training_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/training_ops.cc
@@ -269,9 +269,11 @@
     OP_REQUIRES_OK(context,
                    context->input_list("split_types", &split_types_list));
 
-    const Tensor* feature_ids_t;
-    OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
-    const auto feature_ids = feature_ids_t->vec<int32>();
+    OpInputList feature_ids_list;
+    OP_REQUIRES_OK(context,
+                   context->input_list("feature_ids", &feature_ids_list));
+    // TODO(crawles): Read groups of feature ids and find best splits among all.
+    const auto feature_ids = feature_ids_list[0].vec<int32>();
 
     const Tensor* max_depth_t;
     OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 639a753..276e89a 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -618,7 +618,7 @@
 
 REGISTER_OP("BoostedTreesUpdateEnsembleV2")
     .Input("tree_ensemble_handle: resource")
-    .Input("feature_ids: int32")
+    .Input("feature_ids: num_groups * int32")
     .Input("dimension_ids: num_features * int32")
     .Input("node_ids: num_features * int32")
     .Input("gains: num_features * float")
@@ -631,13 +631,18 @@
     .Input("pruning_mode: int32")
     .Attr("num_features: int >= 0")  // Inferred.
     .Attr("logits_dimension: int = 1")
+    .Attr("num_groups: int = 1")  // Number of groups to process.
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       shape_inference::ShapeHandle shape_handle;
       int num_features;
       TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+      int num_groups;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
 
       // Feature_ids, should be one for each feature.
       shape_inference::ShapeHandle feature_ids_shape;
+      // TODO(crawles): remove 1 hardcode once kernel operates on multiple
+      // groups.
       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
       TF_RETURN_IF_ERROR(
           c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
index 5e82fe4..fec912d9 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
@@ -180,7 +180,7 @@
           pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
           # Tree will be finalized now, since we will reach depth 1.
           max_depth=1,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions, feature2_dimensions],
           node_ids=[feature1_nodes, feature2_nodes],
           gains=[feature1_gains, feature2_gains],
@@ -289,7 +289,7 @@
           pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
           # Tree will be finalized now, since we will reach depth 1.
           max_depth=1,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions, feature2_dimensions],
           node_ids=[feature1_nodes, feature2_nodes],
           gains=[feature1_gains, feature2_gains],
@@ -401,7 +401,7 @@
           pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
           # Tree will be finalized now, since we will reach depth 1.
           max_depth=1,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions, feature2_dimensions],
           node_ids=[feature1_nodes, feature2_nodes],
           gains=[feature1_gains, feature2_gains],
@@ -809,7 +809,7 @@
           pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
           # tree is going to be finalized now, since we reach depth 2.
           max_depth=2,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[
               feature1_dimensions, feature2_dimensions, feature3_dimensions
           ],
@@ -1014,7 +1014,7 @@
           pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
           # tree is going to be finalized now, since we reach depth 2.
           max_depth=2,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[
               feature1_dimensions, feature2_dimensions, feature3_dimensions
           ],
@@ -1230,7 +1230,7 @@
           pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
           # tree is going to be finalized now, since we reach depth 2.
           max_depth=2,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[
               feature1_dimensions, feature2_dimensions, feature3_dimensions
           ],
@@ -1610,7 +1610,7 @@
           pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
           learning_rate=0.1,
           max_depth=2,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions],
           node_ids=[feature1_nodes],
           gains=[feature1_gains],
@@ -1769,7 +1769,7 @@
           pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
           learning_rate=0.1,
           max_depth=2,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions],
           node_ids=[feature1_nodes],
           gains=[feature1_gains],
@@ -1942,7 +1942,7 @@
           pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
           learning_rate=0.1,
           max_depth=2,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions],
           node_ids=[feature1_nodes],
           gains=[feature1_gains],
@@ -2309,7 +2309,7 @@
           pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
           # tree is going to be finalized now, since we reach depth 2.
           max_depth=3,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[
               feature1_dimensions, feature2_dimensions, feature3_dimensions
           ],
@@ -3041,7 +3041,7 @@
           learning_rate=1.0,
           pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
           max_depth=3,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions, feature2_dimensions],
           node_ids=[feature1_nodes, feature2_nodes],
           gains=[feature1_gains, feature2_gains],
@@ -3140,7 +3140,7 @@
           learning_rate=1.0,
           pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
           max_depth=3,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions],
           node_ids=[feature1_nodes],
           gains=[feature1_gains],
@@ -3293,7 +3293,7 @@
           learning_rate=1.0,
           pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
           max_depth=3,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions],
           node_ids=[feature1_nodes],
           gains=[feature1_gains],
@@ -3679,7 +3679,7 @@
           learning_rate=1.0,
           pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
           max_depth=2,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions, feature2_dimensions],
           node_ids=[feature1_nodes, feature2_nodes],
           gains=[feature1_gains, feature2_gains],
@@ -3778,7 +3778,7 @@
           learning_rate=1.0,
           pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
           max_depth=2,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions],
           node_ids=[feature1_nodes],
           gains=[feature1_gains],
@@ -4014,7 +4014,7 @@
           learning_rate=1.0,
           pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
           max_depth=1,
-          feature_ids=feature_ids,
+          feature_ids=[feature_ids],
           dimension_ids=[feature1_dimensions, feature2_dimensions],
           node_ids=[feature1_nodes, feature2_nodes],
           gains=[feature1_gains, feature2_gains],