/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
REGISTER_OP("IsBoostedTreesEnsembleInitialized")
.Input("tree_ensemble_handle: resource")
.Output("is_initialized: bool")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
c->set_output(0, c->Scalar());
return Status::OK();
});
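
// Calculates the best split per node for each feature from per-bucket
// gradient/hessian statistics. The shape function requires every
// stats_summary_list entry to be rank 3 of the form [max_splits, num_buckets, 2]
// (the bucket dimension is not pinned); the trailing dimension presumably
// holds the accumulated gradient and hessian. Hypothetical example: with
// max_splits=3 and num_buckets=5 each entry is [3, 5, 2], and each output
// list entry covers the nodes still eligible for splitting (rank 1, or
// [N, 1] for the contribs outputs).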
REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature")
.Input("node_id_range: int32")
.Input("stats_summary_list: num_features * float32")
.Input("l1: float")
.Input("l2: float")
.Input("tree_complexity: float")
.Input("min_node_weight: float")
.Attr("max_splits: int >= 1")
.Attr("num_features: int >= 1") // not passed but populated automatically.
.Output("node_ids_list: num_features * int32")
.Output("gains_list: num_features * float32")
.Output("thresholds_list: num_features * int32")
.Output("left_node_contribs_list: num_features * float32")
.Output("right_node_contribs_list: num_features * float32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
// Confirms the rank of the inputs and sets the shape of the outputs.
int max_splits;
int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
shape_inference::ShapeHandle node_id_range_shape;
shape_inference::ShapeHandle unused_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
TF_RETURN_IF_ERROR(
c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
// Checks that all stats summary entries are of the same shape.
shape_inference::ShapeHandle summary_shape_base;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &summary_shape_base));
TF_RETURN_IF_ERROR(c->Merge(summary_shape_base,
c->MakeShape({max_splits, -1, 2}),
&unused_shape));
for (int i = 1; i < num_features; ++i) {
shape_inference::ShapeHandle summary_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 3, &summary_shape));
TF_RETURN_IF_ERROR(
c->Merge(summary_shape_base, summary_shape, &unused_shape));
}
TF_RETURN_IF_ERROR(
c->WithRank(c->input(num_features + 1), 0, &unused_shape));
TF_RETURN_IF_ERROR(
c->WithRank(c->input(num_features + 2), 0, &unused_shape));
TF_RETURN_IF_ERROR(
c->WithRank(c->input(num_features + 3), 0, &unused_shape));
// Sets the output lists.
std::vector<shape_inference::ShapeHandle> output_shapes_vec(
num_features, c->MakeShape({-1}));
TF_RETURN_IF_ERROR(c->set_output("node_ids_list", output_shapes_vec));
TF_RETURN_IF_ERROR(c->set_output("gains_list", output_shapes_vec));
TF_RETURN_IF_ERROR(c->set_output("thresholds_list", output_shapes_vec));
std::vector<shape_inference::ShapeHandle> output_shapes_contribs(
num_features, c->MakeShape({-1, 1}));
TF_RETURN_IF_ERROR(
c->set_output("left_node_contribs_list", output_shapes_contribs));
TF_RETURN_IF_ERROR(
c->set_output("right_node_contribs_list", output_shapes_contribs));
return Status::OK();
});
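
// Variant that handles multi-dimensional features and multi-class logits:
// takes a single rank-4 stats summary (presumably the output of
// BoostedTreesAggregateStats below) and returns flat vectors of split
// candidates. The shape function only pins ranks, so all seven outputs get a
// single unknown-length dimension.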
REGISTER_OP("BoostedTreesCalculateBestFeatureSplit")
.Input("node_id_range: int32")
.Input("stats_summary: float32")
.Input("l1: float")
.Input("l2: float")
.Input("tree_complexity: float")
.Input("min_node_weight: float")
.Attr("logits_dimension: int >= 1")
.Attr("split_type: {'inequality'} = 'inequality'")
.Output("node_ids: int32")
.Output("gains: float32")
.Output("feature_dimensions: int32")
.Output("thresholds: int32")
.Output("left_node_contribs: float32")
.Output("right_node_contribs: float32")
.Output("split_with_default_directions: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle node_id_range_shape;
shape_inference::ShapeHandle unused_shape;
// node id range is rank 1 with 2 values.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
TF_RETURN_IF_ERROR(
c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
ShapeHandle output_shape = c->MakeShape({c->UnknownDim()});
for (int i = 0; i < 7; ++i) {
c->set_output(i, output_shape);
}
return Status::OK();
});
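
// Same as above, but the stats summary is supplied in sparse (COO) form:
// stats_summary_indices (rank 2, presumably [N, 4]), stats_summary_values
// (rank 1) and a rank-1 stats_summary_shape with 4 entries giving the dense
// shape.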
REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit")
.Input("node_id_range: int32")
.Input("stats_summary_indices: int32")
.Input("stats_summary_values: float")
.Input("stats_summary_shape: int32")
.Input("l1: float")
.Input("l2: float")
.Input("tree_complexity: float")
.Input("min_node_weight: float")
.Attr("logits_dimension: int >= 1")
.Attr("split_type: {'inequality'} = 'inequality'")
.Output("node_ids: int32")
.Output("gains: float32")
.Output("feature_dimensions: int32")
.Output("thresholds: int32")
.Output("left_node_contribs: float32")
.Output("right_node_contribs: float32")
.Output("split_with_default_directions: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle node_id_range_shape;
shape_inference::ShapeHandle unused_shape;
// node id range is rank 1 with 2 values.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
TF_RETURN_IF_ERROR(
c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused_shape));
shape_inference::ShapeHandle summary_shape;
// stats_summary_shape holds the dense shape of the sparse summary and must
// be a rank-1 vector with 4 entries.
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &summary_shape));
TF_RETURN_IF_ERROR(
c->Merge(summary_shape, c->MakeShape({4}), &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape));
ShapeHandle output_shape = c->MakeShape({-1});
for (int i = 0; i < 7; ++i) {
c->set_output(i, output_shape);
}
return Status::OK();
});
REGISTER_OP("BoostedTreesCreateEnsemble")
.Input("tree_ensemble_handle: resource")
.Input("stamp_token: int64")
.Input("tree_ensemble_serialized: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
return Status::OK();
});
REGISTER_OP("BoostedTreesDeserializeEnsemble")
.Input("tree_ensemble_handle: resource")
.Input("stamp_token: int64")
.Input("tree_ensemble_serialized: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
return Status::OK();
});
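
// Reads summary counters from the ensemble resource. The first four outputs
// are scalars; last_layer_nodes_range is a 2-element vector, presumably the
// [start, end) node id range of the most recently grown layer.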
REGISTER_OP("BoostedTreesGetEnsembleStates")
.Input("tree_ensemble_handle: resource")
.Output("stamp_token: int64")
.Output("num_trees: int32")
.Output("num_finalized_trees: int32")
.Output("num_attempted_layers: int32")
.Output("last_layer_nodes_range: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
c->set_output(0, c->Scalar());
c->set_output(1, c->Scalar());
c->set_output(2, c->Scalar());
c->set_output(3, c->Scalar());
c->set_output(4, c->Vector(2));
return Status::OK();
});
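
// Aggregates per-example gradients and hessians into per-(feature, node,
// bucket) sums. The output is [num_features, max_splits, num_buckets, 2],
// the trailing dimension presumably being the gradient and hessian sums.
// Hypothetical example: num_features=2, max_splits=3, num_buckets=5 yields a
// [2, 3, 5, 2] stats_summary.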
REGISTER_OP("BoostedTreesMakeStatsSummary")
.Input("node_ids: int32")
.Input("gradients: float")
.Input("hessians: float")
.Input("bucketized_features_list: num_features * int32")
.Attr("max_splits: int >= 1")
.Attr("num_buckets: int >= 1")
.Attr("num_features: int >= 1")
.Output("stats_summary: float")
.SetShapeFn([](shape_inference::InferenceContext* c) {
// Sets the shape of the output as a Rank 4 Tensor.
int max_splits;
int num_buckets;
int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets));
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
shape_inference::ShapeHandle node_ids_shape;
shape_inference::ShapeHandle gradients_shape;
shape_inference::ShapeHandle hessians_shape;
shape_inference::ShapeHandle bucketized_feature_shape;
shape_inference::ShapeHandle unused_shape;
shape_inference::DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
c->Dim(gradients_shape, 0), &unused_dim));
TF_RETURN_IF_ERROR(
c->Merge(gradients_shape, hessians_shape, &unused_shape));
for (int f = 0; f < num_features; ++f) {
TF_RETURN_IF_ERROR(
c->WithRank(c->input(3 + f), 1, &bucketized_feature_shape));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
c->Dim(bucketized_feature_shape, 0),
&unused_dim));
}
c->set_output(0,
c->MakeShape({num_features, max_splits, num_buckets, 2}));
return Status::OK();
});
// V2 of BoostedTreesMakeStatsSummary. Supports multi-dimensional dense
// features and multi-class (vector-valued) gradients and hessians.
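// The output is [max_splits, num_buckets, feature_dims, logits_dim +
// hessian_dim]; e.g. with a single logit and a scalar per-example hessian the
// trailing stats dimension is 2, as in the V1 layout above.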
REGISTER_OP("BoostedTreesAggregateStats")
.Input("node_ids: int32")
.Input("gradients: float")
.Input("hessians: float")
.Input("feature: int32")
.Attr("max_splits: int >= 1")
.Attr("num_buckets: int >= 1")
.Output("stats_summary: float")
.SetShapeFn([](shape_inference::InferenceContext* c) {
// Sets the shape of the output as a Rank 4 Tensor.
int max_splits;
int num_buckets;
TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets));
shape_inference::ShapeHandle node_ids_shape;
shape_inference::ShapeHandle gradients_shape;
shape_inference::ShapeHandle hessians_shape;
shape_inference::ShapeHandle feature_shape;
shape_inference::DimensionHandle batch_size = c->Dim(c->input(0), 0);
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &feature_shape));
// Verify that gradients, hessians and feature all share node_ids' first
// dimension, i.e. batch_size.
TF_RETURN_IF_ERROR(c->Merge(c->Dim(gradients_shape, 0),
c->Dim(node_ids_shape, 0), &batch_size));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(hessians_shape, 0),
c->Dim(node_ids_shape, 0), &batch_size));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
c->Dim(node_ids_shape, 0), &batch_size));
DimensionHandle logits_dim = c->Dim(c->input(1), 1);
DimensionHandle hessian_dim = c->Dim(c->input(2), 1);
DimensionHandle feature_dim = c->Dim(c->input(3), 1);
DimensionHandle stats_dim;
TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim));
c->set_output(
0, c->MakeShape({max_splits, num_buckets, feature_dim, stats_dim}));
return Status::OK();
});
// Sparse version of BoostedTreesAggregateStats.
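// Emits the summary as a sparse (COO) tensor: indices [N, 4], values [N] and
// a dense shape vector of length 4, with N unknown at graph-construction
// time.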
REGISTER_OP("BoostedTreesSparseAggregateStats")
.Input("node_ids: int32")
.Input("gradients: float")
.Input("hessians: float")
.Input("feature_indices: int32")
.Input("feature_values: int32")
.Input("feature_shape: int32")
.Attr("max_splits: int >= 1")
.Attr("num_buckets: int >= 1")
.Output("stats_summary_indices: int32")
.Output("stats_summary_values: float")
.Output("stats_summary_shape: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
int max_splits;
int num_buckets;
TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets));
shape_inference::ShapeHandle node_ids_shape;
shape_inference::ShapeHandle gradients_shape;
shape_inference::ShapeHandle hessians_shape;
shape_inference::ShapeHandle feature_indices_shape;
shape_inference::ShapeHandle feature_values_shape;
shape_inference::ShapeHandle feature_shape;
shape_inference::DimensionHandle batch_size = c->Dim(c->input(0), 0);
shape_inference::DimensionHandle num_entries;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &feature_indices_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &feature_values_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &feature_shape));
// node_ids, gradients and hessians must share the batch dimension;
// feature_indices and feature_values must have the same number of entries.
TF_RETURN_IF_ERROR(c->Merge(c->Dim(gradients_shape, 0),
c->Dim(node_ids_shape, 0), &batch_size));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(hessians_shape, 0),
c->Dim(node_ids_shape, 0), &batch_size));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_indices_shape, 0),
c->Dim(feature_values_shape, 0),
&num_entries));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(feature_shape, 0), 2, &unused));
DimensionHandle logits_dim = c->Dim(c->input(1), 1);
DimensionHandle hessian_dim = c->Dim(c->input(2), 1);
DimensionHandle stats_dim;
TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim));
c->set_output(0, c->MakeShape({c->UnknownDim(), 4}));
c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
c->set_output(2, c->MakeShape({4}));
return Status::OK();
});
// TODO(nponomareva): when/if creating the new op for unbucketized data, rename
// bucketized_features to features.
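// Runs each example through the full ensemble and returns logits of shape
// [batch_size, logits_dimension], where batch_size is the common length of
// the bucketized feature vectors.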
REGISTER_OP("BoostedTreesPredict")
.Input("tree_ensemble_handle: resource")
.Input("bucketized_features: num_bucketized_features * int32")
.Attr("num_bucketized_features: int >= 1") // Inferred.
.Attr("logits_dimension: int")
.Output("logits: float")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle feature_shape;
int num_bucketized_features;
TF_RETURN_IF_ERROR(
c->GetAttr("num_bucketized_features", &num_bucketized_features));
shape_inference::ShapeHandle unused_input;
for (int i = 0; i < num_bucketized_features; ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 1), 1, &feature_shape));
// Check that the shapes of all bucketized features are the same.
TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
}
int logits_dimension;
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
auto logits_shape =
c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
// Logits.
c->set_output(0, logits_shape);
return Status::OK();
});
REGISTER_OP("BoostedTreesExampleDebugOutputs")
.Input("tree_ensemble_handle: resource")
.Input("bucketized_features: num_bucketized_features * int32")
.Attr("num_bucketized_features: int >= 1") // Inferred.
.Attr("logits_dimension: int")
.Output("examples_debug_outputs_serialized: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle feature_shape;
int num_bucketized_features;
TF_RETURN_IF_ERROR(
c->GetAttr("num_bucketized_features", &num_bucketized_features));
shape_inference::ShapeHandle unused_input;
for (int i = 0; i < num_bucketized_features; ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 1), 1, &feature_shape));
// Check that the shapes of all bucketized features are the same.
TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
}
// Multi-class will be supported by modifying the proto.
auto batch_size = c->MakeShape({c->Dim(feature_shape, 0)});
c->set_output(0, batch_size);
return Status::OK();
});
REGISTER_OP("BoostedTreesSerializeEnsemble")
.Input("tree_ensemble_handle: resource")
.Output("stamp_token: int64")
.Output("tree_ensemble_serialized: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
c->set_output(0, c->Scalar());
c->set_output(1, c->Scalar());
return Status::OK();
});
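
// Training-time prediction with a cache: cached_tree_ids / cached_node_ids
// record where each example's previous traversal stopped, and the op
// presumably returns only the logits accumulated since that point, together
// with the updated tree and node ids. partial_logits is
// [batch_size, logits_dimension]; tree_ids and node_ids are batch-sized
// vectors.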
REGISTER_OP("BoostedTreesTrainingPredict")
.Input("tree_ensemble_handle: resource")
.Input("cached_tree_ids: int32")
.Input("cached_node_ids: int32")
.Input("bucketized_features: num_bucketized_features * int32")
.Attr("num_bucketized_features: int >= 1")
.Attr("logits_dimension: int")
.Output("partial_logits: float")
.Output("tree_ids: int32")
.Output("node_ids: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle feature_shape;
int num_bucketized_features;
TF_RETURN_IF_ERROR(
c->GetAttr("num_bucketized_features", &num_bucketized_features));
shape_inference::ShapeHandle unused_input;
for (int i = 0; i < num_bucketized_features; ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 3), 1, &feature_shape));
TF_RETURN_IF_ERROR(
c->Merge(c->input(i + 3), feature_shape, &unused_input));
}
// All inputs/outputs except logits share the same batch-sized rank-1 shape.
TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
TF_RETURN_IF_ERROR(c->Merge(c->input(2), feature_shape, &unused_input));
int logits_dimension;
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
auto logits_shape =
c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
// Partial logits.
c->set_output(0, logits_shape);
// Tree ids.
c->set_output(1, c->MakeShape({c->Dim(feature_shape, 0)}));
// Node ids.
c->set_output(2, c->MakeShape({c->Dim(feature_shape, 0)}));
return Status::OK();
});
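
// Grows the ensemble by one layer using the chosen splits. feature_ids has
// one entry per feature; node_ids, gains, thresholds and the two contribs
// inputs are parallel per-feature lists (contribs are rank-2 [N, 1], the
// rest rank-1 of length N).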
REGISTER_OP("BoostedTreesUpdateEnsemble")
.Input("tree_ensemble_handle: resource")
.Input("feature_ids: int32")
.Input("node_ids: num_features * int32")
.Input("gains: num_features * float")
.Input("thresholds: num_features * int32")
.Input("left_node_contribs: num_features * float")
.Input("right_node_contribs: num_features * float")
.Input("max_depth: int32")
.Input("learning_rate: float")
.Attr("pruning_mode: int >=0")
.Attr("num_features: int >= 0") // Inferred.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle shape_handle;
int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
// feature_ids must contain one entry per feature.
shape_inference::ShapeHandle feature_ids_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
TF_RETURN_IF_ERROR(
c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
for (int i = 0; i < num_features; ++i) {
// Node ids.
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle));
auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
auto shape_rank_2 = c->MakeShape({c->Dim(shape_handle, 0), 1});
// Gains.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features + 2), 1, &shape_handle));
// TODO(nponomareva): replace this with input("name",vector of shapes).
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features + 2),
shape_rank_1, &shape_handle));
// Thresholds.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2),
shape_rank_1, &shape_handle));
// Left and right node contribs.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 3 + 2), 2, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2),
shape_rank_2, &shape_handle));
TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2),
shape_rank_2, &shape_handle));
}
return Status::OK();
});
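
// Adjusts the ensemble's bias from the mean gradients and hessians under
// l1/l2 regularization; the scalar output signals whether bias centering
// should continue for another iteration.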
REGISTER_OP("BoostedTreesCenterBias")
.Input("tree_ensemble_handle: resource")
.Input("mean_gradients: float")
.Input("mean_hessians: float")
// Regularization-related.
.Input("l1: float")
.Input("l2: float")
.Output("continue_centering: bool")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle gradients_shape;
shape_inference::ShapeHandle hessians_shape;
shape_inference::ShapeHandle unused_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
TF_RETURN_IF_ERROR(
c->Merge(gradients_shape, hessians_shape, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
c->set_output(0, c->Scalar());
return Status::OK();
});
REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource);
REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized")
.Input("quantile_stream_resource_handle: resource")
.Output("is_initialized: bool")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
c->set_output(0, c->Scalar());
return Status::OK();
});
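
// Creates a quantile stream resource holding num_streams approximate
// quantile sketches with target approximation error epsilon; max_elements
// (default 1 << 40) is presumably an upper bound on the number of values
// each stream expects to see.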
REGISTER_OP("BoostedTreesCreateQuantileStreamResource")
.Attr("max_elements: int = 1099511627776") // 1 << 40
.Input("quantile_stream_resource_handle: resource")
.Input("epsilon: float")
.Input("num_streams: int64")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
return Status::OK();
});
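
// Builds one quantile summary per feature from weighted float values. Each
// summary is a rank-2 [?, 4] tensor whose columns are
// (value, weight, min_rank, max_rank).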
REGISTER_OP("BoostedTreesMakeQuantileSummaries")
.Attr("num_features: int >= 0")
.Input("float_values: num_features * float")
.Input("example_weights: float")
.Input("epsilon: float")
.Output("summaries: num_features * float")
.SetShapeFn([](InferenceContext* c) {
int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
ShapeHandle example_weights_shape;
TF_RETURN_IF_ERROR(
c->WithRank(c->input(num_features), 1, &example_weights_shape));
for (int i = 0; i < num_features; ++i) {
ShapeHandle feature_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape));
// the columns are value, weight, min_rank, max_rank.
c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
}
// epsilon must be a scalar.
ShapeHandle unused_input;
TF_RETURN_IF_ERROR(
c->WithRank(c->input(num_features + 1), 0, &unused_input));
return Status::OK();
});
REGISTER_OP("BoostedTreesFlushQuantileSummaries")
.Attr("num_features: int >= 0")
.Input("quantile_stream_resource_handle: resource")
.Output("summaries: num_features * float")
.SetShapeFn([](InferenceContext* c) {
int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
for (int i = 0; i < num_features; ++i) {
// the columns are value, weight, min_rank, max_rank.
c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
}
return Status::OK();
});
REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
.Attr("num_features: int >= 0")
.Input("quantile_stream_resource_handle: resource")
.Input("summaries: num_features * float")
.SetShapeFn([](InferenceContext* c) {
int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
// resource handle must be a scalar.
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
// each summary must be rank 2.
for (int i = 1; i < num_features + 1; i++) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input));
}
return Status::OK();
});
REGISTER_OP("BoostedTreesQuantileStreamResourceDeserialize")
.Attr("num_streams: int")
.Input("quantile_stream_resource_handle: resource")
.Input("bucket_boundaries: num_streams * float")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
// Each bucket_boundaries input is a rank-1 vector of boundaries.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused_input));
return Status::OK();
});
REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
.Attr("generate_quantiles: bool = False")
.Input("quantile_stream_resource_handle: resource")
.Input("num_buckets: int64")
.SetShapeFn([](InferenceContext* c) {
// All the inputs are scalars.
shape_inference::ShapeHandle unused_input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
return Status::OK();
});
REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
.Attr("num_features: int >= 0")
.Input("quantile_stream_resource_handle: resource")
.Output("bucket_boundaries: num_features * float")
.SetShapeFn([](InferenceContext* c) {
int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
shape_inference::ShapeHandle unused_input;
// resource handle must be a scalar.
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
for (int i = 0; i < num_features; i++) {
c->set_output(i, c->Vector(c->UnknownDim()));
}
return Status::OK();
});
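
// Maps each float value to the index of the bucket whose boundaries contain
// it, per feature. Each output has the same length as the corresponding
// float_values input.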
REGISTER_OP("BoostedTreesBucketize")
.Attr("num_features: int >= 0")
.Input("float_values: num_features * float")
.Input("bucket_boundaries: num_features * float")
.Output("buckets: num_features * int32")
.SetShapeFn([](InferenceContext* c) {
int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
ShapeHandle feature_shape;
DimensionHandle unused_dim;
for (int i = 0; i < num_features; i++) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
c->Dim(c->input(0), 0), &unused_dim));
}
// Bucketized result should have same dimension as input.
for (int i = 0; i < num_features; i++) {
c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0)}));
}
return Status::OK();
});
} // namespace tensorflow