| // Protocol buffer definitions for boosted tree ensembles: tree nodes, the |
| // supported split types, trees, and the metadata kept alongside them. |
| syntax = "proto3"; |
| |
| package tensorflow.boosted_trees; |
| |
| // Enable arena allocation for the generated C++ classes. |
| option cc_enable_arenas = true; |
| // Java codegen: with java_multiple_files each message gets its own top-level |
| // class in java_package; BoostedTreesProtos names the outer wrapper class. |
| option java_outer_classname = "BoostedTreesProtos"; |
| option java_multiple_files = true; |
| option java_package = "org.tensorflow.framework"; |
| |
| // Node describes a node in a tree. |
| // A node carries at most one payload from the `node` oneof (a leaf or one of |
| // the split types) plus metadata. |
| message Node { |
| // Payload of this node; setting one member clears the others. |
| oneof node { |
| // Leaf payload. |
| Leaf leaf = 1; |
| // Split with an integer threshold (see BucketizedSplit). |
| BucketizedSplit bucketized_split = 2; |
| // Split on equality with a categorical value (see CategoricalSplit). |
| CategoricalSplit categorical_split = 3; |
| // Split with a float threshold (see DenseSplit). |
| DenseSplit dense_split = 4; |
| } |
| // Metadata associated with this node (see NodeMetadata). |
| NodeMetadata metadata = 777; |
| } |
| |
| // NodeMetadata encodes metadata associated with each node in a tree. |
| // It is attached to a node through Node.metadata. |
| message NodeMetadata { |
| // The gain associated with this node. |
| float gain = 1; |
| |
| // The original leaf node before this node was split. |
| // As a message-typed field it has explicit presence, so an unset value can |
| // be told apart from an empty Leaf. |
| Leaf original_leaf = 2; |
| } |
| |
| // Leaves can either hold dense or sparse information. |
| message Leaf { |
| // At most one of vector / sparse_vector is set at a time. |
| oneof leaf { |
| // See third_party/tensorflow/contrib/decision_trees/ |
| // proto/generic_tree_model.proto |
| // for a description of how vector and sparse_vector might be used. |
| Vector vector = 1; |
| SparseVector sparse_vector = 2; |
| } |
| // Scalar value of the leaf. Declared outside the oneof, so it can be set |
| // independently of vector / sparse_vector. |
| // NOTE(review): presumably the leaf output for single-valued models -- |
| // confirm against the code that reads Leaf. |
| float scalar = 3; |
| } |
| |
| // Vector is a dense list of float values. Used as a Leaf payload. |
| message Vector { |
| // The values, in order. |
| repeated float value = 1; |
| } |
| |
| // SparseVector stores indices and values as two separate repeated fields. |
| // Used as a Leaf payload. |
| // NOTE(review): presumably value[i] is the entry at position index[i], so |
| // both lists should have the same length -- confirm against readers. |
| message SparseVector { |
| // Indices of the stored entries. |
| repeated int32 index = 1; |
| // Values of the stored entries. |
| repeated float value = 2; |
| } |
| |
| // BucketizedSplit is a split on a float feature column, expressed as the |
| // rule feature <= threshold. |
| message BucketizedSplit { |
| // Direction taken when the feature value is missing. |
| enum DefaultDirection { |
| // Left is the default direction (zero value, so it also applies when |
| // default_direction is left unset). |
| DEFAULT_LEFT = 0; |
| DEFAULT_RIGHT = 1; |
| } |
| |
| // Feature column tested by the rule feature <= threshold. |
| int32 feature_id = 1; |
| // Threshold of the rule feature <= threshold. |
| int32 threshold = 2; |
| // Index of the dimension to split on when the feature column is |
| // multivalent. Defaults to 0. |
| int32 dimension_id = 5; |
| // Default direction for missing values. |
| DefaultDirection default_direction = 6; |
| |
| // Child node ids; each indexes into a contiguous vector of nodes that |
| // starts from the root. |
| int32 left_id = 3; |
| int32 right_id = 4; |
| } |
| |
| // CategoricalSplit is a split on a categorical feature column that tests |
| // the feature for equality with a single value. |
| message CategoricalSplit { |
| // Categorical feature column and split describing the rule feature value == |
| // value. |
| int32 feature_id = 1; |
| int32 value = 2; |
| |
| // Node children indexing into a contiguous |
| // vector of nodes starting from the root. |
| int32 left_id = 3; |
| int32 right_id = 4; |
| } |
| |
| // DenseSplit is a split on a float feature column with a float threshold, |
| // expressed as the rule feature <= threshold. |
| // TODO(nponomareva): move out of boosted_trees and rename to trees.proto |
| message DenseSplit { |
| // Float feature column and split threshold describing |
| // the rule feature <= threshold. |
| int32 feature_id = 1; |
| float threshold = 2; |
| |
| // Node children indexing into a contiguous |
| // vector of nodes starting from the root. |
| int32 left_id = 3; |
| int32 right_id = 4; |
| } |
| |
| // Tree describes a list of connected nodes. |
| // Node 0 must be the root and can carry any payload including a leaf |
| // in the case of representing the bias. |
| // Note that each node id is implicitly its index in the list of nodes. |
| message Tree { |
| // All nodes of the tree; the left_id / right_id fields of the split |
| // messages are indices into this list. |
| repeated Node nodes = 1; |
| } |
| |
| // TreeMetadata records how far a tree has grown, whether it is finalized, |
| // and how node ids moved if the tree was post-pruned. |
| // NOTE(review): field number 1 is not used in this message; if it belonged |
| // to a removed field it should be marked reserved -- confirm the history. |
| message TreeMetadata { |
| // Describes what happened to one node of the pre-pruning tree. |
| message PostPruneNodeUpdate { |
| // Id of this node in the pruned tree or, if the node was pruned, the id |
| // of the node it was collapsed into. |
| int32 new_node_id = 1; |
| // How the values should be adjusted due to post-pruning. |
| repeated float logit_change = 2; |
| } |
| |
| // Number of layers grown for this tree. |
| int32 num_layers_grown = 2; |
| |
| // Whether the tree is finalized in that no more layers can be grown. |
| bool is_finalized = 3; |
| |
| // Node id remapping after post-pruning of a finalized tree. |
| // After post-pruning, a cache may still refer to nodes that were deleted, |
| // or node ids may have changed (e.g. node id 5 became node id 2 because |
| // the other branch was pruned). This mapping tells where each old id now |
| // points and how the values should be adjusted. |
| // - The list size should equal the number of nodes in the tree before |
| // post-pruning happened. |
| // - A pruned node has new_node_id equal to the id of the node it was |
| // collapsed into. |
| // - A node that was not pruned may still have a different id, so its |
| // new_node_id holds the corresponding id in the pruned tree. |
| // - The list is empty if post-pruning didn't happen, or it did and had no |
| // effect (e.g. no nodes got pruned). |
| repeated PostPruneNodeUpdate post_pruned_nodes_meta = 4; |
| } |
| |
| // GrowingMetadata tracks how many trees and layers have been attempted and |
| // where the most recently grown layer sits. It is stored on the ensemble |
| // through TreeEnsemble.growing_metadata. |
| message GrowingMetadata { |
| // Number of trees that we have attempted to build. After pruning, these |
| // trees might have been removed. |
| int64 num_trees_attempted = 1; |
| // Number of layers that we have attempted to build. After pruning, these |
| // layers might have been removed. |
| int64 num_layers_attempted = 2; |
| // The start (inclusive) and end (exclusive) ids of the nodes in the latest |
| // layer of the latest tree. |
| int32 last_layer_node_start = 3; |
| // Exclusive end of the range that starts at last_layer_node_start. |
| int32 last_layer_node_end = 4; |
| } |
| |
| // TreeEnsemble describes an ensemble of decision trees. |
| // NOTE(review): trees, tree_weights and tree_metadata are separate repeated |
| // fields; presumably entry i of each refers to the same tree -- confirm |
| // against the code that builds the ensemble. |
| message TreeEnsemble { |
| // The trees of the ensemble. |
| repeated Tree trees = 1; |
| // Weights of the trees. |
| repeated float tree_weights = 2; |
| |
| // Metadata of the trees (see TreeMetadata). |
| repeated TreeMetadata tree_metadata = 3; |
| // Metadata that is used during the training. |
| GrowingMetadata growing_metadata = 4; |
| } |
| |
| // DebugOutput contains outputs useful for debugging/model interpretation, at |
| // the individual example-level. Debug outputs that are available to the user |
| // are: |
| // 1) Directional feature contributions (DFCs), |
| // 2) Node IDs for ensemble prediction path, |
| // 3) Leaf node IDs. |
| // Only the inputs for 1) are defined so far; see the TODO below. |
| message DebugOutput { |
| // Return the logits and associated feature splits across prediction paths for |
| // each tree, for every example, at predict time. We will use these values to |
| // compute DFCs in Python, by subtracting each child prediction from its |
| // parent prediction and associating this change with its respective feature |
| // id. |
| // feature_ids holds the feature ids of the splits on the prediction paths. |
| repeated int32 feature_ids = 1; |
| // logits_path holds the logits along the prediction paths. |
| // NOTE(review): how entries of logits_path line up with feature_ids is not |
| // stated here -- confirm against the producer. |
| repeated float logits_path = 2; |
| |
| // TODO(crawles): return 2) Node IDs for ensemble prediction path 3) Leaf node |
| // IDs. |
| } |