blob: 7cd62af3a9580f6675386b93a52cfb927e9e5a10 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <string>
#include <vector>
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
// The Op used during training time to get the predictions so far with the
// current ensemble being built.
// Expect some logits are cached from the previous step and passed through
// to be reused.
class BoostedTreesTrainingPredictOp : public OpKernel {
public:
explicit BoostedTreesTrainingPredictOp(OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
&num_bucketized_features_));
OP_REQUIRES_OK(context,
context->GetAttr("logits_dimension", &logits_dimension_));
}
void Compute(OpKernelContext* const context) override {
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
// Get the inputs.
OpInputList bucketized_features_list;
OP_REQUIRES_OK(context, context->input_list("bucketized_features",
&bucketized_features_list));
std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
batch_bucketized_features.reserve(bucketized_features_list.size());
for (const Tensor& tensor : bucketized_features_list) {
batch_bucketized_features.emplace_back(tensor.vec<int32>());
}
const int batch_size = batch_bucketized_features[0].size();
const Tensor* cached_tree_ids_t;
OP_REQUIRES_OK(context,
context->input("cached_tree_ids", &cached_tree_ids_t));
const auto cached_tree_ids = cached_tree_ids_t->vec<int32>();
const Tensor* cached_node_ids_t;
OP_REQUIRES_OK(context,
context->input("cached_node_ids", &cached_node_ids_t));
const auto cached_node_ids = cached_node_ids_t->vec<int32>();
// Allocate outputs.
Tensor* output_partial_logits_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("partial_logits",
{batch_size, logits_dimension_},
&output_partial_logits_t));
auto output_partial_logits = output_partial_logits_t->matrix<float>();
Tensor* output_tree_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("tree_ids", {batch_size},
&output_tree_ids_t));
auto output_tree_ids = output_tree_ids_t->vec<int32>();
Tensor* output_node_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("node_ids", {batch_size},
&output_node_ids_t));
auto output_node_ids = output_node_ids_t->vec<int32>();
// Indicate that the latest tree was used.
const int32 latest_tree = resource->num_trees() - 1;
if (latest_tree < 0) {
// Ensemble was empty. Output the very first node.
output_node_ids.setZero();
output_tree_ids = cached_tree_ids;
// All the predictions are zeros.
output_partial_logits.setZero();
} else {
output_tree_ids.setConstant(latest_tree);
auto do_work = [&resource, &batch_bucketized_features, &cached_tree_ids,
&cached_node_ids, &output_partial_logits,
&output_node_ids, latest_tree,
this](int32 start, int32 end) {
for (int32 i = start; i < end; ++i) {
int32 tree_id = cached_tree_ids(i);
int32 node_id = cached_node_ids(i);
std::vector<float> partial_tree_logits(logits_dimension_, 0.0);
if (node_id >= 0) {
// If the tree was pruned, returns the node id into which the
// current_node_id was pruned, as well the correction of the cached
// logit prediction.
resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
&partial_tree_logits);
// Logic in the loop adds the cached node value again if it is a
// leaf. If it is not a leaf anymore we need to subtract the old
// node's value. The following logic handles both of these cases.
const auto& node_logits = resource->node_value(tree_id, node_id);
if (!node_logits.empty()) {
DCHECK_EQ(node_logits.size(), logits_dimension_);
for (int32 j = 0; j < logits_dimension_; ++j) {
partial_tree_logits[j] -= node_logits[j];
}
}
} else {
// No cache exists, start from the very first node.
node_id = 0;
}
std::vector<float> partial_all_logits(logits_dimension_, 0.0);
while (true) {
if (resource->is_leaf(tree_id, node_id)) {
const auto& leaf_logits = resource->node_value(tree_id, node_id);
DCHECK_EQ(leaf_logits.size(), logits_dimension_);
// Tree is done
const float tree_weight = resource->GetTreeWeight(tree_id);
for (int32 j = 0; j < logits_dimension_; ++j) {
partial_all_logits[j] +=
tree_weight * (partial_tree_logits[j] + leaf_logits[j]);
partial_tree_logits[j] = 0;
}
// Stop if it was the latest tree.
if (tree_id == latest_tree) {
break;
}
// Move onto other trees.
++tree_id;
node_id = 0;
} else {
node_id = resource->next_node(tree_id, node_id, i,
batch_bucketized_features);
}
}
output_node_ids(i) = node_id;
for (int32 j = 0; j < logits_dimension_; ++j) {
output_partial_logits(i, j) = partial_all_logits[j];
}
}
};
// 30 is the magic number. The actual value might be a function of (the
// number of layers) * (cpu cycles spent on each layer), but this value
// would work for many cases. May be tuned later.
const int64 cost = 30;
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
Shard(worker_threads->NumThreads(), worker_threads, batch_size,
/*cost_per_unit=*/cost, do_work);
}
}
private:
int32 logits_dimension_; // the size of the output prediction vector.
int32 num_bucketized_features_; // Indicates the number of features.
};
REGISTER_KERNEL_BUILDER(Name("BoostedTreesTrainingPredict").Device(DEVICE_CPU),
BoostedTreesTrainingPredictOp);
// The Op to get the predictions at the evaluation/inference time.
class BoostedTreesPredictOp : public OpKernel {
public:
explicit BoostedTreesPredictOp(OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
&num_bucketized_features_));
OP_REQUIRES_OK(context,
context->GetAttr("logits_dimension", &logits_dimension_));
}
void Compute(OpKernelContext* const context) override {
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
// Get the inputs.
OpInputList bucketized_features_list;
OP_REQUIRES_OK(context, context->input_list("bucketized_features",
&bucketized_features_list));
std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
batch_bucketized_features.reserve(bucketized_features_list.size());
for (const Tensor& tensor : bucketized_features_list) {
batch_bucketized_features.emplace_back(tensor.vec<int32>());
}
const int batch_size = batch_bucketized_features[0].size();
// Allocate outputs.
Tensor* output_logits_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
"logits", {batch_size, logits_dimension_},
&output_logits_t));
auto output_logits = output_logits_t->matrix<float>();
// Return zero logits if it's an empty ensemble.
if (resource->num_trees() <= 0) {
output_logits.setZero();
return;
}
const int32 last_tree = resource->num_trees() - 1;
auto do_work = [&resource, &batch_bucketized_features, &output_logits,
last_tree, this](int32 start, int32 end) {
for (int32 i = start; i < end; ++i) {
std::vector<float> tree_logits(logits_dimension_, 0.0);
int32 tree_id = 0;
int32 node_id = 0;
while (true) {
if (resource->is_leaf(tree_id, node_id)) {
const float tree_weight = resource->GetTreeWeight(tree_id);
const auto& leaf_logits = resource->node_value(tree_id, node_id);
DCHECK_EQ(leaf_logits.size(), logits_dimension_);
for (int32 j = 0; j < logits_dimension_; ++j) {
tree_logits[j] += tree_weight * leaf_logits[j];
}
// Stop if it was the last tree.
if (tree_id == last_tree) {
break;
}
// Move onto other trees.
++tree_id;
node_id = 0;
} else {
node_id = resource->next_node(tree_id, node_id, i,
batch_bucketized_features);
}
}
for (int32 j = 0; j < logits_dimension_; ++j) {
output_logits(i, j) = tree_logits[j];
}
}
};
// 10 is the magic number. The actual number might depend on (the number of
// layers in the trees) and (cpu cycles spent on each layer), but this
// value would work for many cases. May be tuned later.
const int64 cost = (last_tree + 1) * 10;
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
Shard(worker_threads->NumThreads(), worker_threads, batch_size,
/*cost_per_unit=*/cost, do_work);
}
private:
int32
logits_dimension_; // Indicates the size of the output prediction vector.
int32 num_bucketized_features_; // Indicates the number of features.
};
REGISTER_KERNEL_BUILDER(Name("BoostedTreesPredict").Device(DEVICE_CPU),
BoostedTreesPredictOp);
// The Op that returns debugging/model interpretability outputs for each
// example. Currently it outputs the split feature ids and logits after each
// split along the decision path for each example. This will be used to compute
// directional feature contributions at predict time for an arbitrary activation
// function.
// TODO(crawles): return in proto 1) Node IDs for ensemble prediction path
// 2) Leaf node IDs.
class BoostedTreesExampleDebugOutputsOp : public OpKernel {
public:
explicit BoostedTreesExampleDebugOutputsOp(
OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
&num_bucketized_features_));
OP_REQUIRES_OK(context,
context->GetAttr("logits_dimension", &logits_dimension_));
OP_REQUIRES(context, logits_dimension_ == 1,
errors::InvalidArgument(
"Currently only one dimensional outputs are supported."));
}
void Compute(OpKernelContext* const context) override {
core::RefCountPtr<BoostedTreesEnsembleResource> resource;
// Get the resource.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
// Get the inputs.
OpInputList bucketized_features_list;
OP_REQUIRES_OK(context, context->input_list("bucketized_features",
&bucketized_features_list));
std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
batch_bucketized_features.reserve(bucketized_features_list.size());
for (const Tensor& tensor : bucketized_features_list) {
batch_bucketized_features.emplace_back(tensor.vec<int32>());
}
const int batch_size = batch_bucketized_features[0].size();
// We need to get the feature ids used for splitting and the logits after
// each split. We will use these to calculate the changes in the prediction
// (contributions) for an arbitrary activation function (done in Python) and
// attribute them to the associated feature ids. We will store these in
// a proto below.
Tensor* output_debug_info_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output("examples_debug_outputs_serialized",
{batch_size}, &output_debug_info_t));
// Will contain serialized protos, per example.
auto output_debug_info = output_debug_info_t->flat<tstring>();
const int32 last_tree = resource->num_trees() - 1;
// For each given example, traverse through all trees keeping track of the
// features used to split and the associated logits at each point along the
// path. Note: feature_ids has one less value than logits_path because the
// first value of each logit path will be the bias.
auto do_work = [&resource, &batch_bucketized_features, &output_debug_info,
last_tree](int32 start, int32 end) {
for (int32 i = start; i < end; ++i) {
// Proto to store debug outputs, per example.
boosted_trees::DebugOutput example_debug_info;
// Initial bias prediction. E.g., prediction based off training mean.
const auto& tree_logits = resource->node_value(0, 0);
DCHECK_EQ(tree_logits.size(), 1);
float tree_logit = resource->GetTreeWeight(0) * tree_logits[0];
example_debug_info.add_logits_path(tree_logit);
int32 node_id = 0;
int32 tree_id = 0;
int32 feature_id;
float past_trees_logit = 0; // Sum of leaf logits from prior trees.
// Go through each tree and populate proto.
while (tree_id <= last_tree) {
if (resource->is_leaf(tree_id, node_id)) { // Move onto other trees.
// Accumulate tree_logits only if the leaf is non-root, but do so
// for bias tree.
if (tree_id == 0 || node_id > 0) {
past_trees_logit += tree_logit;
}
++tree_id;
node_id = 0;
} else { // Add to proto.
// Feature id used to split.
feature_id = resource->feature_id(tree_id, node_id);
example_debug_info.add_feature_ids(feature_id);
// Get logit after split.
node_id = resource->next_node(tree_id, node_id, i,
batch_bucketized_features);
const auto& tree_logits = resource->node_value(tree_id, node_id);
DCHECK_EQ(tree_logits.size(), 1);
tree_logit = resource->GetTreeWeight(tree_id) * tree_logits[0];
// Output logit incorporates sum of leaf logits from prior trees.
example_debug_info.add_logits_path(tree_logit + past_trees_logit);
}
}
// Set output as serialized proto containing debug info.
string serialized = example_debug_info.SerializeAsString();
output_debug_info(i) = serialized;
}
};
// 10 is the magic number. The actual number might depend on (the number of
// layers in the trees) and (cpu cycles spent on each layer), but this
// value would work for many cases. May be tuned later.
const int64 cost = (last_tree + 1) * 10;
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
Shard(worker_threads->NumThreads(), worker_threads, batch_size,
/*cost_per_unit=*/cost, do_work);
}
private:
int32 logits_dimension_; // Indicates dimension of logits in the tree nodes.
int32 num_bucketized_features_; // Indicates the number of features.
};
REGISTER_KERNEL_BUILDER(
Name("BoostedTreesExampleDebugOutputs").Device(DEVICE_CPU),
BoostedTreesExampleDebugOutputsOp);
} // namespace tensorflow