/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/util/incremental_barrier.h"
namespace tensorflow {
namespace serving {
namespace {
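// The helpers below export batching metrics (padding size, input and
// processed batch sizes, batching delay, and the static batching parameters)
// through the monitoring API, keyed by model name.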
void RecordPaddingSize(int32 padding_size, const string& model_name,
int32 execution_batch_size) {
static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
{"/tensorflow/serving/batching/padding_size",
"Tracks the padding size distribution on batches by model_name (if "
"available).",
"model_name", "execution_batch_size"},
/*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
/*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
cell->GetCell(model_name, absl::StrCat(execution_batch_size))
->Add(static_cast<double>(padding_size));
}
void RecordInputBatchSize(int32 batch_size, const string& model_name) {
static auto* cell = tensorflow::monitoring::PercentileSampler<1>::New(
{"/tensorflow/serving/batching/input_batch_size",
"Tracks the batch size distribution on the inputs by model_name (if "
"available).",
"model_name"},
/*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
/*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
cell->GetCell(model_name)->Add(static_cast<double>(batch_size));
}
void RecordProcessedBatchSize(int32 batch_size, const string& model_name) {
static auto* cell = tensorflow::monitoring::PercentileSampler<1>::New(
{"/tensorflow/serving/batching/processed_batch_size",
"Tracks the batch size distribution on processing by model_name (if "
"available).",
"model_name"},
/*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
/*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
cell->GetCell(model_name)->Add(static_cast<double>(batch_size));
}
void RecordBatchDelayUs(int64 batch_delay_us, const string& model_name) {
static auto* cell = monitoring::PercentileSampler<1>::New(
{"/tensorflow/serving/batching/batch_delay_us",
"Tracks the batching delay (in microseconds) for inputs by model_name "
"(if available).",
"model_name"},
/*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
/*max_samples=*/1024, monitoring::UnitOfMeasure::kTime);
cell->GetCell(model_name)->Add(static_cast<double>(batch_delay_us));
}
void RecordBatchParamBatchTimeoutMicros(int64 batch_timeout_micros,
const string& model_name) {
static auto* cell = monitoring::Gauge<int64, 1>::New(
"/tensorflow/serving/batching/batch_timeout_micros",
"Tracks how long a request can wait before being processed by a batch.",
"model_name");
cell->GetCell(model_name)->Set(batch_timeout_micros);
}
void RecordBatchParamMaxBatchSize(int64 max_batch_size,
const string& model_name) {
static auto* cell = monitoring::Gauge<int64, 1>::New(
"/tensorflow/serving/batching/max_batch_size",
"Tracks the maximum size of a batch.", "model_name");
cell->GetCell(model_name)->Set(max_batch_size);
}
void RecordBatchParamMaxEnqueuedBatches(int64 max_enqueued_batches,
const string& model_name) {
static auto* cell = monitoring::Gauge<int64, 1>::New(
"/tensorflow/serving/batching/max_enqueued_batches",
"Tracks the maximum number of enqueued batches.", "model_name");
cell->GetCell(model_name)->Set(max_enqueued_batches);
}
void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes,
const string& model_name) {
static auto* cell = monitoring::Gauge<string, 1>::New(
"/tensorflow/serving/batching/allowed_batch_sizes",
"Tracks the sizes that are allowed to form a batch.", "model_name");
cell->GetCell(model_name)->Set(allowed_batch_sizes);
}
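
// Returns the model name recorded in the OpKernelContext's session metadata,
// or a "model_name_unset" placeholder when no name is available.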
const string& GetModelName(OpKernelContext* ctx) {
static string* kModelNameUnset = new string("model_name_unset");
if (!ctx->session_metadata()) return *kModelNameUnset;
if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
return ctx->session_metadata()->name();
}
} // namespace
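
// Creates one split of this task. The split shares this task's guid, captured
// inputs, OpKernelContext, output matrix, and status, is tagged with its
// `split_index`, and is marked partial so that its outputs are re-assembled
// once all splits have completed.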
std::unique_ptr<BatchResourceBase::BatchTask>
BatchResourceBase::BatchTask::CreateSplitTask(
int split_index, AsyncOpKernel::DoneCallback done_callback) {
std::unique_ptr<BatchTask> task = CreateDerivedTask();
task->guid = this->guid;
task->propagated_context = Context(ContextKind::kThread);
task->inputs.reserve(this->inputs.size());
task->captured_inputs = this->captured_inputs;
task->context = this->context;
task->done_callback = done_callback;
task->split_index = split_index;
task->output = this->output;
task->status = this->status;
task->is_partial = true;
task->start_time = this->start_time;
return task;
}
using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;
using TensorMatrix = std::vector<std::vector<Tensor>>;
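
// Validates the op's input tensors, wraps them (together with any captured
// tensors) in a BatchTask, records the batching metrics, and schedules the
// task on the batcher queue named `batcher_queue_name`.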
Status BatchResourceBase::RegisterInput(
int64 guid, OpKernelContext* context, const string& batcher_queue_name,
AsyncOpKernel::DoneCallback done_callback) {
std::unique_ptr<BatchTask> batch_components;
TF_RETURN_IF_ERROR(CreateBatchTask(context, &batch_components));
batch_components->start_time = EnvTime::NowNanos();
batch_components->guid = guid;
batch_components->propagated_context = Context(ContextKind::kThread);
OpInputList tensors;
TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
batch_components->inputs.reserve(tensors.size());
for (const Tensor& tensor : tensors) {
if (tensor.shape().dims() == 0) {
return errors::InvalidArgument(
"Batching input tensors must have at least one dimension");
}
if (tensors.size() >= 2 &&
tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
return errors::InvalidArgument(
"Batching input tensors supplied in a given op invocation must "
"have equal 0th-dimension size");
}
batch_components->inputs.push_back(tensor);
}
RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context));
RecordBatchParamBatchTimeoutMicros(
batcher_queue_options_.batch_timeout_micros, GetModelName(context));
RecordBatchParamMaxBatchSize(batcher_queue_options_.max_execution_batch_size,
GetModelName(context));
RecordBatchParamMaxEnqueuedBatches(
batcher_queue_options_.max_enqueued_batches, GetModelName(context));
RecordBatchParamAllowedBatchSizes(allowed_batch_sizes_str_,
GetModelName(context));
OpInputList captured_tensors;
const auto captured_status =
context->input_list("captured_tensors", &captured_tensors);
if (captured_status.ok()) {
batch_components->captured_inputs.reserve(captured_tensors.size());
for (const Tensor& captured_tensor : captured_tensors) {
batch_components->captured_inputs.push_back(captured_tensor);
}
}
batch_components->context = context;
batch_components->done_callback = std::move(done_callback);
batch_components->split_index = 0;
batch_components->output = std::make_shared<TensorMatrix>();
batch_components->status = std::make_shared<ThreadSafeStatus>();
BatcherQueueT* batcher_queue;
TF_RETURN_IF_ERROR(
LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
return batcher_queue->Schedule(&batch_components);
}
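
// Converts the batching op's attributes into queue options for the shared
// batch scheduler. When large-batch splitting is enabled, oversized tasks are
// split via SplitInputTask and the execution batch size is capped at the last
// entry of `allowed_batch_sizes` (or `max_batch_size` if the list is empty).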
/*static*/ BatchResourceBase::BatcherT::QueueOptions
BatchResourceBase::GetBatcherQueueOptions(
int32 num_batch_threads, int32 max_batch_size, int32 batch_timeout_micros,
int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
bool enable_large_batch_splitting) {
BatcherT::QueueOptions batcher_queue_options;
batcher_queue_options.input_batch_size_limit = max_batch_size;
batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
batcher_queue_options.enable_large_batch_splitting =
enable_large_batch_splitting;
if (enable_large_batch_splitting) {
batcher_queue_options.split_input_task_func =
[](std::unique_ptr<BatchTask>* input_task,
int open_batch_remaining_slot, int max_batch_size,
std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
return SplitInputTask(input_task, open_batch_remaining_slot,
max_batch_size, output_tasks);
};
if (allowed_batch_sizes.empty()) {
batcher_queue_options.max_execution_batch_size = max_batch_size;
} else {
batcher_queue_options.max_execution_batch_size =
*allowed_batch_sizes.rbegin();
}
}
return batcher_queue_options;
}
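
// Converts the batching op's attributes into queue options for the adaptive
// shared batch scheduler.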
/*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions
BatchResourceBase::GetAdaptiveBatcherQueueOptions(
int32 max_batch_size, int32 batch_timeout_micros,
int32 max_enqueued_batches, bool enable_large_batch_splitting) {
AdaptiveBatcherT::QueueOptions batcher_queue_options;
batcher_queue_options.max_batch_size = max_batch_size;
batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
if (enable_large_batch_splitting) {
batcher_queue_options.split_input_task_func =
[](std::unique_ptr<BatchTask>* input_task,
int open_batch_remaining_slot, int max_batch_size,
std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
return SplitInputTask(input_task, open_batch_remaining_slot,
max_batch_size, output_tasks);
};
}
return batcher_queue_options;
}
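
// Verifies that every task in the batch has the same number of input tensors.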
/*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) {
for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
const BatchResourceBase::BatchTask& task = batch.task(task_idx);
if (task.inputs.size() != batch.task(0).inputs.size()) {
return errors::InvalidArgument(
"Batching inputs must have equal number of edges");
}
}
return Status::OK();
}
// Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
// or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
// returns 'batch_size'.
int BatchResourceBase::RoundToLowestAllowedBatchSize(int batch_size) const {
if (allowed_batch_sizes_.empty()) {
return batch_size;
}
for (int allowed_size : allowed_batch_sizes_) {
if (allowed_size >= batch_size) {
return allowed_size;
}
}
LOG(ERROR) << "Maximum batch size greater than largest allowed size; "
"ignoring allowed sizes constraint";
return batch_size;
}
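
// Concatenates the tasks' input tensors along dimension zero, producing one
// tensor per input edge. The batch is padded up to the next allowed batch
// size by repeating the first row of the first task's corresponding tensor.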
Status BatchResourceBase::ConcatInputTensors(
const BatchT& batch, OpKernelContext* context,
std::vector<Tensor>* concatenated_tensors) const {
if (batch.num_tasks() == 0) {
return errors::InvalidArgument("Empty batch.");
}
const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
const int padding_amount = padded_batch_size - batch.size();
profiler::TraceMe trace_me([padded_batch_size, padding_amount]() {
return profiler::TraceMeEncode(
"ConcatInputTensors", {{"batch_size_after_padding", padded_batch_size},
{"padding_amount", padding_amount}});
});
RecordPaddingSize(padding_amount, GetModelName(context), padded_batch_size);
RecordProcessedBatchSize(padded_batch_size, GetModelName(context));
// All tasks should have the same number of input edges.
const int num_inputs = batch.task(0).inputs.size();
concatenated_tensors->reserve(num_inputs);
// Process each input one at a time (the typical case has just one).
for (int i = 0; i < num_inputs; ++i) {
// Concatenate the tasks' ith input tensors into one output tensor.
std::vector<Tensor> to_concatenate;
to_concatenate.reserve(batch.num_tasks());
for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
}
// Add padding as needed. Use the first row of the first task's tensor as
// the data for padding.
if (padding_amount > 0) {
const Tensor& padding_source = batch.task(0).inputs.at(i);
Tensor padding;
if (padding_source.shape().dim_size(0) == 0) {
return errors::InvalidArgument(
"Cannot use an empty tensor with zero rows as padding when "
"batching. (Input ",
i, " got shape ", padding_source.shape().DebugString(), ".)");
}
if (padding_source.shape().dim_size(0) == 1) {
padding = padding_source;
} else {
padding = padding_source.Slice(0, 1);
}
for (int p = 0; p < padding_amount; ++p) {
to_concatenate.push_back(padding);
}
}
Tensor concatenated_tensor;
Status concat_status =
Concat(context, to_concatenate, &concatenated_tensor);
TF_RETURN_IF_ERROR(concat_status);
concatenated_tensors->push_back(concatenated_tensor);
}
return Status::OK();
}
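
// Splits `input_task` into smaller tasks: the first split fills the
// `open_batch_remaining_slot` of the currently open batch (if any), and the
// rest is chopped into chunks of at most `max_batch_size`. All splits share
// one output matrix and status, and an IncrementalBarrier runs
// `split_task_done_callback` only after every split has completed.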
/*static*/ Status BatchResourceBase::SplitInputTask(
std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot,
int max_batch_size, std::vector<std::unique_ptr<BatchTask>>* output_tasks) {
BatchTask& input_task = *(*input_task_ptr);
const int64 input_task_size = input_task.size();
DCHECK_GT(input_task_size, open_batch_remaining_slot);
std::shared_ptr<ThreadSafeStatus> shared_status = input_task.status;
// `split_task_done_callback` runs only after all of the split tasks have
// completed.
std::function<void()> split_task_done_callback =
[done_callback = input_task.done_callback, output = input_task.output,
op_kernel_context = input_task.context, status = shared_status]() {
const int num_output = op_kernel_context->num_outputs();
for (int i = 0; i < num_output; ++i) {
Tensor output_tensor;
// Concat copies each input tensor into the single output tensor. In this
// context, Concat could be optimized further to avoid some (possibly all)
// of the copies when the input tensors are slices of one larger tensor.
std::vector<Tensor> to_concatenate;
to_concatenate.reserve(output->size());
for (int j = 0; j < output->size(); ++j) {
to_concatenate.push_back(std::move((*output)[j][i]));
}
const auto concat_status =
Concat(op_kernel_context, to_concatenate, &output_tensor);
if (!concat_status.ok()) {
status->Update(concat_status);
}
op_kernel_context->set_output(i, std::move(output_tensor));
}
op_kernel_context->SetStatus(status->status());
done_callback();
};
IncrementalBarrier barrier(split_task_done_callback);
std::vector<int64> output_task_sizes;
if (open_batch_remaining_slot > 0) {
output_task_sizes.push_back(open_batch_remaining_slot);
}
for (int left_task_size = input_task_size - open_batch_remaining_slot;
left_task_size > 0; left_task_size -= max_batch_size) {
int next_task_size = std::min(left_task_size, max_batch_size);
output_task_sizes.push_back(next_task_size);
}
const int output_task_num = output_task_sizes.size();
input_task.output->resize(output_task_num);
for (int i = 0; i < output_task_num; ++i) {
(*input_task.output)[i].resize(input_task.context->num_outputs());
}
output_tasks->reserve(output_task_num);
for (int i = 0; i < output_task_num; i++) {
output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
}
const int num_input_tensors = input_task.inputs.size();
// Splits each input tensor according to `output_task_sizes`, and initializes
// the inputs of `output_tasks` with the split results.
for (int i = 0; i < num_input_tensors; ++i) {
std::vector<Tensor> split_tensors;
const Tensor& input_tensor = input_task.inputs[i];
// TODO(b/154140947):
// Figure out the optimal implementation of Split, by using
// 'Tensor::Slice' and eliminating unnecessary memcpy as much as possible.
const Status split_status = Split(input_task.context, input_tensor,
output_task_sizes, &split_tensors);
if (!split_status.ok()) {
return errors::Internal(
"When splitting input, Tensor split operation failed: ",
split_status.ToString());
}
if (split_tensors.size() != output_task_sizes.size()) {
return errors::Internal(
"When splitting input, tensor split operation did not work as "
"expected; got ",
split_tensors.size(), " splits; expected ", output_task_sizes.size());
}
for (int j = 0; j < output_tasks->size(); ++j) {
BatchTask& output_task = *((*output_tasks)[j]);
auto moved_tensor_iter = std::next(split_tensors.begin(), j);
std::move(moved_tensor_iter, moved_tensor_iter + 1,
std::back_inserter(output_task.inputs));
}
}
return Status::OK();
}
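
// Slices the concatenated outputs back into per-task outputs according to
// each task's size, ignoring any trailing rows that were added as padding.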
Status BatchResourceBase::SplitOutputTensors(
const std::vector<Tensor>& combined_outputs, BatchT* batch) const {
DCHECK_GE(batch->num_tasks(), 1);
if (batch->num_tasks() < 1) {
return errors::Internal("Batch size expected to be positive; was ",
batch->num_tasks());
}
std::vector<int64> task_sizes_plus_optional_padding;
task_sizes_plus_optional_padding.reserve(batch->num_tasks());
for (int i = 0; i < batch->num_tasks(); ++i) {
task_sizes_plus_optional_padding.push_back(batch->task(i).size());
}
const int padding_size =
RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
if (padding_size > 0) {
task_sizes_plus_optional_padding.push_back(padding_size);
}
// For each output tensor name, a divided-up tensor with one entry per task.
std::map<string, std::vector<Tensor>> split_tensors;
DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
int combined_outputs_size = combined_outputs.size();
if (combined_outputs_size != batch->task(0).context->num_outputs()) {
return errors::Internal("Wrong number of batched output tensors");
}
// Generate 'split_tensors' and populate the context outputs.
for (int i = 0, iter_limit = combined_outputs.size(); i < iter_limit; ++i) {
const Tensor& output_tensor = combined_outputs[i];
if (output_tensor.shape().dims() == 0) {
return errors::FailedPrecondition(
"Batched output tensor has 0 dimensions");
}
if (output_tensor.shape().dim_size(0) !=
static_cast<int64>(batch->size() + padding_size)) {
return errors::FailedPrecondition(
"Batched output tensor's 0th dimension does not equal the sum of "
"the 0th dimension sizes of the input tensors");
}
std::vector<Tensor> split_tensor;
const Status split_status = tensor::Split(
output_tensor, task_sizes_plus_optional_padding, &split_tensor);
DCHECK(split_status.ok()) << split_status.ToString();
if (!split_status.ok()) {
return errors::Internal("Tensor split operation failed: ",
split_status.ToString());
}
DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
return errors::Internal(
"Tensor split operation did not work as expected; got ",
split_tensor.size(), " splits; expected ",
task_sizes_plus_optional_padding.size());
}
// Ignore a possible final split_tensors entry containing the padding.
for (int j = 0; j < batch->num_tasks(); ++j) {
BatchTask& task = *(batch->mutable_task(j));
if (task.is_partial) {
std::vector<Tensor>& tensor_vector = (*task.output)[task.split_index];
tensor_vector[i] = std::move(split_tensor[j]);
} else {
task.context->set_output(i, split_tensor[j]);
}
}
}
return Status::OK();
}
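
// Processes a batch destined for a batch function: concatenates the inputs,
// invokes the function via ProcessFuncBatchImpl, and splits the outputs back
// to the individual tasks. Each task's done callback is always invoked, even
// when an error occurs.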
void BatchResourceBase::ProcessFuncBatch(std::unique_ptr<BatchT> batch) const {
if (batch->empty()) {
return;
}
// We use the 'propagated_context' from one of the threads that set up one of
// the tasks. This propagates any common context to all of the threads running
// this Session, of which this BatchOp is a part.
WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);
auto& last_task = batch->task(batch->num_tasks() - 1);
OpKernelContext* last_task_context = last_task.context;
// Regardless of the outcome, we need to propagate the status to the
// individual tasks and signal that they are done. We use MakeCleanup() to
// ensure that this happens no matter how we exit the method below.
Status status;
bool cleanup_done = false;
auto cleanup_fn = [&cleanup_done, &batch](const Status& status) {
if (cleanup_done) {
return;
}
for (int i = 0; i < batch->num_tasks(); ++i) {
if (batch->task(i).is_partial) {
batch->mutable_task(i)->status->Update(status);
} else {
batch->mutable_task(i)->context->SetStatus(status);
}
batch->mutable_task(i)->done_callback();
}
cleanup_done = true;
};
auto finally =
gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });
status = ValidateBatch(*batch);
if (!status.ok()) {
return;
}
std::vector<Tensor> concatenated_tensors;
status = ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
if (!status.ok()) {
return;
}
std::vector<Tensor> combined_outputs;
std::vector<Tensor> args(concatenated_tensors.begin(),
concatenated_tensors.end());
const auto& captured_inputs =
batch->task(batch->num_tasks() - 1).captured_inputs;
args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());
uint64 current_time = EnvTime::NowNanos();
const string& model_name = GetModelName(last_task_context);
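// Record the batching delay for each task, converting nanoseconds to
// microseconds.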
for (int i = 0; i < batch->num_tasks(); ++i) {
RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3,
model_name);
}
// Release the cleanup here; the function library runtime's callback is now
// responsible for running it.
finally.release();
ProcessFuncBatchImpl(
last_task, args, &combined_outputs, [&](const Status& run_status) {
Status final_status;
auto run_finally = gtl::MakeCleanup([&]() {
// We do the cleanup here as an optimization, so that it runs in the
// underlying TF inter-op threadpool. Running it in the threadpool lets the
// ensuing ops be scheduled sooner, because the executor will add them to the
// front of the threadpool's task queue rather than the end.
cleanup_fn(final_status);
});
final_status = run_status;
if (!final_status.ok()) {
return;
}
final_status = SplitOutputTensors(combined_outputs, batch.get());
});
}
// Processes a batch of one or more BatchTask entries.
void BatchResourceBase::ProcessBatch(std::unique_ptr<BatchT> batch) const {
if (batch->empty()) {
return;
}
WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);
OpKernelContext* last_task_context =
batch->task(batch->num_tasks() - 1).context;
AsyncOpKernel::DoneCallback last_task_callback =
batch->task(batch->num_tasks() - 1).done_callback;
OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
last_task_callback);
// All tasks should have the same number of input edges.
const int num_input_edges = batch->task(0).inputs.size();
std::vector<Tensor> concatenated_tensors;
const Status concat_status =
ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);
// Process each input edge one at a time (the typical case has just one).
for (int i = 0; i < num_input_edges; ++i) {
last_task_context->set_output(i, concatenated_tensors[i]);
// Emit batch->num_tasks() - 1 empty output tensors.
for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
const BatchTask& task = batch->task(task_idx);
TensorShape output_shape(task.inputs[i].shape());
output_shape.set_dim(0, 0);
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
task.context, task.context->allocate_output(i, output_shape, &output),
task.done_callback);
}
}
// Emit batch->num_tasks() - 1 empty index tensors.
for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
const BatchTask& task = batch->task(task_idx);
TensorShape index_shape({0, 3});
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
task.context,
task.context->allocate_output(num_input_edges, index_shape, &output),
task.done_callback);
}
// Emit all ID tensors.
for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
const BatchTask& task = batch->task(task_idx);
Tensor* id;
OP_REQUIRES_OK_ASYNC(task.context,
task.context->allocate_output(num_input_edges + 1,
TensorShape({}), &id),
task.done_callback);
id->scalar<int64>()() = task.guid;
}
OP_REQUIRES_OK_ASYNC(
last_task_context,
EmitIndexTensor(last_task_context, *batch, num_input_edges),
last_task_callback);
// Signal done for each element of the batch. (At this point, the contexts
// are no longer guaranteed to remain live.)
for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
batch->mutable_task(task_idx)->done_callback();
}
}
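
// Writes the batch index tensor to `output_index`: one row per task holding
// the task's guid and its [start, end) offsets within the concatenated batch.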
/*static*/ Status BatchResourceBase::EmitIndexTensor(OpKernelContext* context,
const BatchT& batch,
int output_index) {
const TensorShape index_shape({batch.num_tasks(), 3});
Tensor* index = nullptr;
TF_RETURN_IF_ERROR(
context->allocate_output(output_index, index_shape, &index));
auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
size_t offset = 0;
for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
const BatchTask& task = batch.task(task_idx);
index_flat(task_idx, 0) = task.guid;
index_flat(task_idx, 1) = offset;
index_flat(task_idx, 2) = offset + task.size();
offset += task.size();
}
return Status::OK();
}
// Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
// creates it.
Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name,
BatcherQueueT** queue) {
mutex_lock l(batcher_queues_mu_);
auto it = batcher_queues_.find(queue_name);
if (it != batcher_queues_.end()) {
*queue = it->second.get();
return Status::OK();
}
std::unique_ptr<BatcherQueueT> new_queue;
auto process_batch_callback = [this](std::unique_ptr<BatchT> batch) {
if (!has_process_batch_function_) {
ProcessBatch(std::move(batch));
} else {
ProcessFuncBatch(std::move(batch));
}
};
TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
process_batch_callback, &new_queue));
*queue = new_queue.get();
batcher_queues_[queue_name] = std::move(new_queue);
return Status::OK();
}
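
// Creates the BatchTask that RegisterInput fills in and schedules.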
Status BatchResourceBase::CreateBatchTask(
OpKernelContext* context,
std::unique_ptr<BatchResourceBase::BatchTask>* output) const {
*output = absl::make_unique<BatchResourceBase::BatchTask>();
return Status::OK();
}
} // namespace serving
} // namespace tensorflow