In the batch kernel, use the adaptive shared batch scheduler and share this scheduler across all models. This is enabled only if 'num_batch_threads' is not positive.
PiperOrigin-RevId: 343336941
Change-Id: I8555233c86b3520f50bf21c716cf8c5218a66384
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 36fee01c..32d229e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -635,9 +635,11 @@
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
"//tensorflow/core/kernels/batching_util:batch_resource_base",
"//tensorflow/core/kernels/batching_util:concat_split_util",
"//tensorflow/core/kernels/batching_util:periodic_function_dynamic",
+ "//tensorflow/core/platform:numbers",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index 4a53716..5447fac 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -20,6 +20,7 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
@@ -29,8 +30,15 @@
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/numbers.h"
namespace tensorflow {
+namespace {
+constexpr int64 kMinInflightBatchesLimit = 16;
+constexpr double kInitialInflightBatchesLimit = 64;
+constexpr int64 kBatchesToAverageOver = 10;
+constexpr int64 kMaxInflightBatchesLimit = 128;
+} // namespace
auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New(
"/tensorflow/serving/batching/enable_large_batch_splitting",
@@ -94,6 +102,24 @@
return Status::OK();
}
+ static Status Create(
+ AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options,
+ int32 max_batch_size, int32 batch_timeout_micros,
+ int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
+ FunctionLibraryRuntime::Handle fhandle,
+ std::unique_ptr<BatchResource>* resource) {
+ std::shared_ptr<AdaptiveBatcherT> batcher;
+ TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create(
+ adaptive_shared_batch_scheduler_options, &batcher));
+
+ resource->reset(new BatchResource(
+ fhandle, std::move(batcher),
+ GetAdaptiveBatcherQueueOptions(max_batch_size, batch_timeout_micros,
+ max_enqueued_batches, true),
+ allowed_batch_sizes));
+ return Status::OK();
+ }
+
string DebugString() const final { return "BatchResource"; }
private:
@@ -107,6 +133,16 @@
std::move(allowed_batch_sizes)),
fhandle_(fhandle) {}
+ BatchResource(FunctionLibraryRuntime::Handle fhandle,
+ std::shared_ptr<AdaptiveBatcherT> batcher,
+ const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
+ std::vector<int32> allowed_batch_sizes)
+ : BatchResourceBase(
+ /*has_process_batch_function=*/fhandle != kInvalidHandle,
+ std::move(batcher), batcher_queue_options,
+ std::move(allowed_batch_sizes)),
+ fhandle_(fhandle) {}
+
void ProcessFuncBatchImpl(
const BatchTask& last_task, absl::Span<const Tensor> inputs,
std::vector<Tensor>* combined_outputs,
@@ -142,11 +178,6 @@
explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
- // If shared_name is not supplied, use name instead (prevent collisions by
- // default).
- if (shared_name_.empty()) {
- shared_name_ = name();
- }
OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
@@ -162,6 +193,29 @@
OP_REQUIRES_OK(c, c->GetAttr("f", &func));
OP_REQUIRES_OK(
c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
+ if (num_batch_threads_ <= 0) {
+ adaptive_batch_scheduler_options_ =
+ absl::make_optional(AdaptiveBatchSchedulerOptions{
+ kMinInflightBatchesLimit, kInitialInflightBatchesLimit,
+ kBatchesToAverageOver});
+
+ // Use a shared thread pool across all models if adaptive shared batch
+ // scheduler is used.
+ // `shared_name_` and `container_` are used to look up an instantiated
+ // scheduler instance in `ComputeAsync`.
+ container_ = "__adapative_container";
+ shared_name_ = "__adaptive_global_shared_thread_pool";
+ // If batching_queue is not supplied, use name to prevent collisions by default.
+ if (batcher_queue_.empty()) {
+ batcher_queue_ = name();
+ }
+ }
+
+ if (shared_name_.empty()) {
+ // If shared_name is not supplied, use name instead (prevent collisions by
+ // default).
+ shared_name_ = name();
+ }
if (c->HasAttr("enable_large_batch_splitting")) {
OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
@@ -185,16 +239,58 @@
GetModelName(c));
// TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel.
RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c));
+
+ std::function<Status(BatchResource**)> creator;
+
+ if (adaptive_batch_scheduler_options_ != absl::nullopt) {
+ creator = [this](BatchResource** r) {
+ serving::AdaptiveSharedBatchScheduler<
+ serving::BatchResourceBase::BatchTask>::Options
+ adaptive_shared_batch_scheduler_options;
+ adaptive_shared_batch_scheduler_options.thread_pool_name =
+ "adaptive_batch_threads";
+ adaptive_shared_batch_scheduler_options.num_batch_threads =
+ kMaxInflightBatchesLimit;
+ // adaptive_shared_batch_scheduler_options.full_batch_scheduling_boost_micros
+ // is intentionally left at 0 (its default value), so tasks are scheduled
+ // in FIFO order.
+ // Two rationales for keeping the default value (zero) for
+ // `full_batch_scheduling_boost_micros`:
+ // 1) The task scheduling policy is then FIFO. Compared with round
+ // robin (what shared batch scheduler does), FIFO ensures that models
+ // with low QPS (i.e., models that enqueue fewer tasks in the shared
+ // queue) are processed in a timely manner.
+ // 2) If set, `full_batch_scheduling_boost_micros` should be on the order
+ // of the batch processing latency (which varies from model to model).
+ // If a non-zero value is not set properly, it harms tail latency.
+ adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit =
+ adaptive_batch_scheduler_options_->min_in_flight_batches_limit;
+ adaptive_shared_batch_scheduler_options
+ .initial_in_flight_batches_limit =
+ adaptive_batch_scheduler_options_->initial_in_flight_batches_limit;
+ adaptive_shared_batch_scheduler_options.batches_to_average_over =
+ adaptive_batch_scheduler_options_->batches_to_average_over;
+ std::unique_ptr<BatchResource> new_resource;
+ TF_RETURN_IF_ERROR(BatchResource::Create(
+ adaptive_shared_batch_scheduler_options, max_batch_size_,
+ batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
+ fhandle_, &new_resource));
+ *r = new_resource.release();
+ return Status::OK();
+ };
+ } else {
+ creator = [this](BatchResource** r) {
+ std::unique_ptr<BatchResource> new_resource;
+ TF_RETURN_IF_ERROR(BatchResource::Create(
+ num_batch_threads_, max_batch_size_, batch_timeout_micros_,
+ max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
+ enable_large_batch_splitting_, &new_resource));
+ *r = new_resource.release();
+ return Status::OK();
+ };
+ }
+
BatchResource* br;
- std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
- std::unique_ptr<BatchResource> new_resource;
- TF_RETURN_IF_ERROR(BatchResource::Create(
- num_batch_threads_, max_batch_size_, batch_timeout_micros_,
- max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
- enable_large_batch_splitting_, &new_resource));
- *r = new_resource.release();
- return Status::OK();
- };
OP_REQUIRES_OK_ASYNC(c,
c->resource_manager()->LookupOrCreate(
container_, shared_name_, &br, creator),
@@ -244,6 +340,17 @@
FunctionLibraryRuntime::Handle fhandle_;
bool enable_large_batch_splitting_;
bool has_attribute_enable_large_batch_splitting_;
+
+ // Parameters for adaptive batch scheduler only.
+ // Note 'num_batch_threads_' above is shared by two implementations of batch
+ // scheduler.
+ struct AdaptiveBatchSchedulerOptions {
+ int64 min_in_flight_batches_limit;
+ double initial_in_flight_batches_limit;
+ int64 batches_to_average_over;
+ };
+ absl::optional<AdaptiveBatchSchedulerOptions>
+ adaptive_batch_scheduler_options_ = absl::nullopt;
};
REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
@@ -292,8 +399,8 @@
// Assume br calls done, so nothing to do here.
}
- // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
- // and the last one must equal 'max_batch_size_'.
+ // Validates 'allowed_batch_sizes_'. The entries must increase
+ // monotonically, and the last one must equal 'max_batch_size_'.
Status ValidateAllowedBatchSizes() const {
if (allowed_batch_sizes_.empty()) {
return Status::OK();
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 1e546ac..75c270f 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -241,6 +241,7 @@
srcs = ["batch_resource_base.cc"],
hdrs = ["batch_resource_base.h"],
deps = [
+ ":adaptive_shared_batch_scheduler",
":batch_scheduler",
":concat_split_util",
":shared_batch_scheduler",
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
index c0bcbb1..e4af643 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
@@ -204,7 +204,6 @@
batcher_queue_options.input_batch_size_limit = max_batch_size;
batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
- // Support for splitting large batch is still in progress.
batcher_queue_options.enable_large_batch_splitting =
enable_large_batch_splitting;
if (enable_large_batch_splitting) {
@@ -227,6 +226,28 @@
return batcher_queue_options;
}
+/*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions
+BatchResourceBase::GetAdaptiveBatcherQueueOptions(
+ int32 max_batch_size, int32 batch_timeout_micros,
+ int32 max_enqueued_batches, bool enable_large_batch_splitting) {
+ AdaptiveBatcherT::QueueOptions batcher_queue_options;
+ batcher_queue_options.max_batch_size = max_batch_size;
+ batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
+ batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
+
+ if (enable_large_batch_splitting) {
+ batcher_queue_options.split_input_task_func =
+ [](std::unique_ptr<BatchTask>* input_task,
+ int open_batch_remaining_slot, int max_batch_size,
+ std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
+ return SplitInputTask(input_task, open_batch_remaining_slot,
+ max_batch_size, output_tasks);
+ };
+ }
+
+ return batcher_queue_options;
+}
+
/*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) {
for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
const BatchResourceBase::BatchTask& task = batch.task(task_idx);
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h
index 6fe11c8..ea8c772 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.h
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h
@@ -23,6 +23,7 @@
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
@@ -49,7 +50,7 @@
const string& batcher_queue_name,
AsyncOpKernel::DoneCallback done_callback);
- protected:
+ public:
// One task to be batched, corresponds to a `slice` of input from one batch-op
// invocation.
//
@@ -107,6 +108,8 @@
// tensorflow::serving namespace, because some versions of compiler complain
// about changing meaning of the symbols.
using BatcherT = SharedBatchScheduler<BatchResourceBase::BatchTask>;
+ using AdaptiveBatcherT =
+ AdaptiveSharedBatchScheduler<BatchResourceBase::BatchTask>;
using BatcherQueueT = BatchScheduler<BatchResourceBase::BatchTask>;
using BatchT = Batch<BatchResourceBase::BatchTask>;
@@ -121,11 +124,24 @@
allowed_batch_sizes_str_ = absl::StrJoin(allowed_batch_sizes_, ",");
}
+ BatchResourceBase(bool has_process_batch_function,
+ std::shared_ptr<AdaptiveBatcherT> batcher,
+ const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
+ std::vector<int32> allowed_batch_sizes)
+ : has_process_batch_function_(has_process_batch_function),
+ adaptive_batcher_(std::move(batcher)),
+ adaptive_batcher_queue_options_(batcher_queue_options),
+ allowed_batch_sizes_(std::move(allowed_batch_sizes)) {}
+
static BatcherT::QueueOptions GetBatcherQueueOptions(
int32 num_batch_threads, int32 max_batch_size, int32 batch_timeout_micros,
int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
bool enable_large_batch_splitting);
+ static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions(
+ int32 max_batch_size, int32 batch_timeout_micros,
+ int32 max_enqueued_batches, bool enable_large_batch_splitting);
+
private:
// Implementation of calling the process batch function.
virtual void ProcessFuncBatchImpl(
@@ -199,6 +215,10 @@
std::shared_ptr<BatcherT> batcher_;
BatcherT::QueueOptions batcher_queue_options_;
+ // An adaptive shared batch scheduler, and options for creating its queues.
+ std::shared_ptr<AdaptiveBatcherT> adaptive_batcher_;
+ AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_;
+
// A collection of batcher queues, keyed on queue name.
// TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
// ones (with a time delay?); it's okay if they get recreated later).