// blob: 4e0f4c7d56cc18007caab7ce13381e5eec3337a0 [file] [log] [blame]
syntax = "proto3";

package tensorflow.boosted_trees;

// C++ code generation: enable arena allocation for the generated messages.
option cc_enable_arenas = true;

// Java code generation: emit one .java file per top-level message, with
// file-level declarations placed in the BoostedTreesProtos outer class, all
// under the org.tensorflow.framework package.
option java_multiple_files = true;
option java_outer_classname = "BoostedTreesProtos";
option java_package = "org.tensorflow.framework";
// Node describes a node in a tree: one payload chosen from the `node` oneof
// (a leaf or one of the split kinds) plus metadata about the node.
message Node {
  // The node payload. Setting one member clears the others.
  oneof node {
    // Terminal node.
    Leaf leaf = 1;
    // Split on a feature using an int32 threshold (rule: feature <= threshold).
    BucketizedSplit bucketized_split = 2;
    // Split on a categorical feature (rule: feature value == value).
    CategoricalSplit categorical_split = 3;
    // Split on a feature using a float threshold (rule: feature <= threshold).
    DenseSplit dense_split = 4;
  }

  // Metadata associated with this node. Kept outside the oneof so that it can
  // accompany any payload; NodeMetadata.original_leaf, for instance, records
  // the leaf that existed before this node was split.
  NodeMetadata metadata = 777;
}
// NodeMetadata encodes metadata associated with each node in a tree.
message NodeMetadata {
  // The gain associated with this node.
  float gain = 1;

  // The original leaf node before this node was split.
  Leaf original_leaf = 2;
}
// Leaves can either hold dense or sparse information.
message Leaf {
  // Dense or sparse leaf payload. Setting one member clears the other.
  oneof leaf {
    // See third_party/tensorflow/contrib/decision_trees/
    // proto/generic_tree_model.proto
    // for a description of how vector and sparse_vector might be used.
    Vector vector = 1;
    SparseVector sparse_vector = 2;
  }

  // Scalar value of the leaf.
  // NOTE(review): declared outside the `leaf` oneof, i.e. it does not
  // participate in the dense/sparse choice above -- confirm against the
  // generated code in use.
  float scalar = 3;
}
// Vector is a dense list of float values (see Leaf.vector).
message Vector {
  // The values, in order.
  repeated float value = 1;
}
// SparseVector is a sparse list of float values (see Leaf.sparse_vector),
// stored as two repeated fields.
message SparseVector {
  // Indices of the stored entries.
  // NOTE(review): presumably index[i] is the position of value[i], so both
  // lists should have the same length -- confirm against the readers.
  repeated int32 index = 1;

  // Values of the stored entries.
  repeated float value = 2;
}
// BucketizedSplit is a threshold split whose threshold is an int32.
message BucketizedSplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  int32 feature_id = 1;
  int32 threshold = 2;

  // If feature column is multivalent, this holds the index of the dimension
  // for the split. Defaults to 0.
  int32 dimension_id = 5;

  // Direction taken when the feature value is missing.
  // NOTE(review): the enumerator lines were absent from this copy of the
  // file; the names below are reconstructed from the surviving comment and
  // from the left_id/right_id fields. Confirm them against existing callers.
  enum DefaultDirection {
    // Left is the default direction.
    DEFAULT_LEFT = 0;
    DEFAULT_RIGHT = 1;
  }

  // Default direction for missing values.
  DefaultDirection default_direction = 6;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}
// CategoricalSplit is an equality split on a categorical feature.
message CategoricalSplit {
  // Categorical feature column and split describing the rule feature value ==
  // value.
  int32 feature_id = 1;
  int32 value = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}
// TODO(nponomareva): move out of boosted_trees and rename to trees.proto

// DenseSplit is a threshold split whose threshold is a float.
message DenseSplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  int32 feature_id = 1;
  float threshold = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}
// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
// Note that each node id is implicitly its index in the list of nodes.
message Tree {
  // The nodes of the tree; a node's id is its position in this list.
  repeated Node nodes = 1;
}
// TreeMetadata holds per-tree bookkeeping: how many layers were grown,
// whether the tree is finalized, and how node ids moved during post-pruning.
message TreeMetadata {
  // Number of layers grown for this tree.
  int32 num_layers_grown = 2;

  // Whether the tree is finalized in that no more layers can be grown.
  bool is_finalized = 3;

  // If tree was finalized and post pruning happened, it is possible that cache
  // still refers to some nodes that were deleted or that the node ids changed
  // (e.g. node id 5 became node id 2 due to pruning of the other branch).
  // The mapping below allows us to understand where the old ids now map to and
  // how the values should be adjusted due to post-pruning.
  // The size of the list should be equal to the number of nodes in the tree
  // before post-pruning happened.
  // If the node was pruned, it will have new_node_id equal to the id of a node
  // that this node was collapsed into. For a node that didn't get pruned, it is
  // possible that its id still changed, so new_node_id will have the
  // corresponding id in the pruned tree.
  // If post-pruning didn't happen, or it did and it had no effect (e.g. no
  // nodes got pruned), this list will be empty.
  repeated PostPruneNodeUpdate post_pruned_nodes_meta = 4;

  // One entry of the post-pruning mapping described above.
  // NOTE(review): kept nested inside TreeMetadata, which determines the
  // generated type name -- confirm against existing generated code.
  message PostPruneNodeUpdate {
    // Id, in the pruned tree, that the old node now maps to.
    int32 new_node_id = 1;

    // How the values should be adjusted due to post-pruning.
    repeated float logit_change = 2;
  }
}
// GrowingMetadata tracks how much of the ensemble has been attempted so far
// and where the most recently grown layer lives.
message GrowingMetadata {
  // Number of trees that we have attempted to build. After pruning, these
  // trees might have been removed.
  int64 num_trees_attempted = 1;

  // Number of layers that we have attempted to build. After pruning, these
  // layers might have been removed.
  int64 num_layers_attempted = 2;

  // The start (inclusive) and end (exclusive) ids of the nodes in the latest
  // layer of the latest tree.
  int32 last_layer_node_start = 3;
  int32 last_layer_node_end = 4;
}
// TreeEnsemble describes an ensemble of decision trees.
message TreeEnsemble {
  // The trees of the ensemble.
  repeated Tree trees = 1;

  // Weights of the trees.
  // NOTE(review): presumably tree_weights[i] and tree_metadata[i] both refer
  // to trees[i] -- confirm against the code that builds the ensemble.
  repeated float tree_weights = 2;

  // Metadata of the trees.
  repeated TreeMetadata tree_metadata = 3;

  // Metadata that is used during the training.
  GrowingMetadata growing_metadata = 4;
}
// DebugOutput contains outputs useful for debugging/model interpretation, at
// the individual example-level. Debug outputs that are available to the user
// are: 1) Directional feature contributions (DFCs) 2) Node IDs for ensemble
// prediction path 3) Leaf node IDs.
message DebugOutput {
  // Return the logits and associated feature splits across prediction paths for
  // each tree, for every example, at predict time. We will use these values to
  // compute DFCs in Python, by subtracting each child prediction from its
  // parent prediction and associating this change with its respective feature
  // id.
  repeated int32 feature_ids = 1;
  repeated float logits_path = 2;

  // TODO(crawles): return 2) Node IDs for ensemble prediction path 3) Leaf node
  // IDs.
}