In the batch kernel, use an adaptive shared batch scheduler and share it across all models. This path is only enabled when 'num_batch_threads' is not positive.
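
Illustrative sketch (not part of the patch): how the kernel configures the
adaptive scheduler when 'num_batch_threads' <= 0. The helper name
MakeAdaptiveSchedulerOptions is hypothetical; the option fields and the
kMin/kInitial/kMax/kBatchesToAverageOver constants are the ones added in
batch_kernels.cc below.

    #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
    #include "tensorflow/core/kernels/batching_util/batch_resource_base.h"

    namespace tensorflow {

    using AdaptiveBatcherT = serving::AdaptiveSharedBatchScheduler<
        serving::BatchResourceBase::BatchTask>;

    // Builds the options for the process-wide adaptive scheduler. The
    // k*Limit constants are those defined in the anonymous namespace of
    // batch_kernels.cc.
    AdaptiveBatcherT::Options MakeAdaptiveSchedulerOptions() {
      AdaptiveBatcherT::Options options;
      options.thread_pool_name = "adaptive_batch_threads";
      options.num_batch_threads = kMaxInflightBatchesLimit;            // 128
      options.min_in_flight_batches_limit = kMinInflightBatchesLimit;  // 16
      options.initial_in_flight_batches_limit =
          kInitialInflightBatchesLimit;                                // 64
      options.batches_to_average_over = kBatchesToAverageOver;         // 10
      // full_batch_scheduling_boost_micros is left at 0 so tasks are
      // scheduled FIFO across models.
      return options;
    }

    }  // namespace tensorflow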

PiperOrigin-RevId: 343336941
Change-Id: I8555233c86b3520f50bf21c716cf8c5218a66384
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 36fee01c..32d229e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -635,9 +635,11 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
         "//tensorflow/core/kernels/batching_util:batch_resource_base",
         "//tensorflow/core/kernels/batching_util:concat_split_util",
         "//tensorflow/core/kernels/batching_util:periodic_function_dynamic",
+        "//tensorflow/core/platform:numbers",
         "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index 4a53716..5447fac 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -20,6 +20,7 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
 #include "tensorflow/core/kernels/batching_util/concat_split_util.h"
 #include "tensorflow/core/kernels/batching_util/periodic_function.h"
@@ -29,8 +30,15 @@
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/numbers.h"
 
 namespace tensorflow {
+namespace {
+constexpr int64 kMinInflightBatchesLimit = 16;
+constexpr double kInitialInflightBatchesLimit = 64;
+constexpr int64 kBatchesToAverageOver = 10;
+constexpr int64 kMaxInflightBatchesLimit = 128;
+}  // namespace
 
 auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New(
     "/tensorflow/serving/batching/enable_large_batch_splitting",
@@ -94,6 +102,24 @@
     return Status::OK();
   }
 
+  static Status Create(
+      AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options,
+      int32 max_batch_size, int32 batch_timeout_micros,
+      int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
+      FunctionLibraryRuntime::Handle fhandle,
+      std::unique_ptr<BatchResource>* resource) {
+    std::shared_ptr<AdaptiveBatcherT> batcher;
+    TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create(
+        adaptive_shared_batch_scheduler_options, &batcher));
+
+    resource->reset(new BatchResource(
+        fhandle, std::move(batcher),
+        GetAdaptiveBatcherQueueOptions(max_batch_size, batch_timeout_micros,
+                                       max_enqueued_batches, true),
+        allowed_batch_sizes));
+    return Status::OK();
+  }
+
   string DebugString() const final { return "BatchResource"; }
 
  private:
@@ -107,6 +133,16 @@
             std::move(allowed_batch_sizes)),
         fhandle_(fhandle) {}
 
+  BatchResource(FunctionLibraryRuntime::Handle fhandle,
+                std::shared_ptr<AdaptiveBatcherT> batcher,
+                const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
+                std::vector<int32> allowed_batch_sizes)
+      : BatchResourceBase(
+            /*has_process_batch_function=*/fhandle != kInvalidHandle,
+            std::move(batcher), batcher_queue_options,
+            std::move(allowed_batch_sizes)),
+        fhandle_(fhandle) {}
+
   void ProcessFuncBatchImpl(
       const BatchTask& last_task, absl::Span<const Tensor> inputs,
       std::vector<Tensor>* combined_outputs,
@@ -142,11 +178,6 @@
   explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
-    // If shared_name is not supplied, use name instead (prevent collisions by
-    // default).
-    if (shared_name_.empty()) {
-      shared_name_ = name();
-    }
     OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
     OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
     OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
@@ -162,6 +193,29 @@
     OP_REQUIRES_OK(c, c->GetAttr("f", &func));
     OP_REQUIRES_OK(
         c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
+    if (num_batch_threads_ <= 0) {
+      adaptive_batch_scheduler_options_ =
+          absl::make_optional(AdaptiveBatchSchedulerOptions{
+              kMinInflightBatchesLimit, kInitialInflightBatchesLimit,
+              kBatchesToAverageOver});
+
+      // Use a shared thread pool across all models when the adaptive shared
+      // batch scheduler is used.
+      // `shared_name_` and `container_` are used to look up an instantiated
+      // scheduler instance in `ComputeAsync`.
+      container_ = "__adaptive_container";
+      shared_name_ = "__adaptive_global_shared_thread_pool";
+      // Default the batcher queue name to the op name to prevent collisions.
+      if (batcher_queue_.empty()) {
+        batcher_queue_ = name();
+      }
+    }
+
+    if (shared_name_.empty()) {
+      // If shared_name is not supplied, use the op name instead to prevent
+      // collisions by default.
+      shared_name_ = name();
+    }
 
     if (c->HasAttr("enable_large_batch_splitting")) {
       OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
@@ -185,16 +239,58 @@
         GetModelName(c));
     // TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel.
     RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c));
+
+    std::function<Status(BatchResource**)> creator;
+
+    if (adaptive_batch_scheduler_options_ != absl::nullopt) {
+      creator = [this](BatchResource** r) {
+        serving::AdaptiveSharedBatchScheduler<
+            serving::BatchResourceBase::BatchTask>::Options
+            adaptive_shared_batch_scheduler_options;
+        adaptive_shared_batch_scheduler_options.thread_pool_name =
+            "adaptive_batch_threads";
+        adaptive_shared_batch_scheduler_options.num_batch_threads =
+            kMaxInflightBatchesLimit;
+        // `full_batch_scheduling_boost_micros` is intentionally left at its
+        // default value (zero), so tasks are scheduled in FIFO order. Two
+        // rationales for keeping the zero default:
+        // 1) With zero boost the scheduling policy is FIFO. Compared with
+        // round robin (what the shared batch scheduler does), FIFO ensures
+        // that models with low QPS (i.e., models that enqueue fewer tasks
+        // into the shared queue) are still processed in a timely manner.
+        // 2) If set, `full_batch_scheduling_boost_micros` should be on the
+        // order of the batch processing latency (which varies per model);
+        // a non-zero value that is not tuned properly harms tail latency.
+        adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit =
+            adaptive_batch_scheduler_options_->min_in_flight_batches_limit;
+        adaptive_shared_batch_scheduler_options
+            .initial_in_flight_batches_limit =
+            adaptive_batch_scheduler_options_->initial_in_flight_batches_limit;
+        adaptive_shared_batch_scheduler_options.batches_to_average_over =
+            adaptive_batch_scheduler_options_->batches_to_average_over;
+        std::unique_ptr<BatchResource> new_resource;
+        TF_RETURN_IF_ERROR(BatchResource::Create(
+            adaptive_shared_batch_scheduler_options, max_batch_size_,
+            batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
+            fhandle_, &new_resource));
+        *r = new_resource.release();
+        return Status::OK();
+      };
+    } else {
+      creator = [this](BatchResource** r) {
+        std::unique_ptr<BatchResource> new_resource;
+        TF_RETURN_IF_ERROR(BatchResource::Create(
+            num_batch_threads_, max_batch_size_, batch_timeout_micros_,
+            max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
+            enable_large_batch_splitting_, &new_resource));
+        *r = new_resource.release();
+        return Status::OK();
+      };
+    }
+
     BatchResource* br;
-    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
-      std::unique_ptr<BatchResource> new_resource;
-      TF_RETURN_IF_ERROR(BatchResource::Create(
-          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
-          max_enqueued_batches_, allowed_batch_sizes_, fhandle_,
-          enable_large_batch_splitting_, &new_resource));
-      *r = new_resource.release();
-      return Status::OK();
-    };
     OP_REQUIRES_OK_ASYNC(c,
                          c->resource_manager()->LookupOrCreate(
                              container_, shared_name_, &br, creator),
@@ -244,6 +340,17 @@
   FunctionLibraryRuntime::Handle fhandle_;
   bool enable_large_batch_splitting_;
   bool has_attribute_enable_large_batch_splitting_;
+
+  // Parameters for the adaptive batch scheduler only.
+  // Note that 'num_batch_threads_' above is shared by both batch scheduler
+  // implementations.
+  struct AdaptiveBatchSchedulerOptions {
+    int64 min_in_flight_batches_limit;
+    double initial_in_flight_batches_limit;
+    int64 batches_to_average_over;
+  };
+  absl::optional<AdaptiveBatchSchedulerOptions>
+      adaptive_batch_scheduler_options_ = absl::nullopt;
 };
 
 REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
@@ -292,8 +399,8 @@
     // Assume br calls done, so nothing to do here.
   }
 
-  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
-  // and the last one must equal 'max_batch_size_'.
+  // Validates 'allowed_batch_sizes_'. The entries must increase
+  // monotonically, and the last one must equal 'max_batch_size_'.
   Status ValidateAllowedBatchSizes() const {
     if (allowed_batch_sizes_.empty()) {
       return Status::OK();
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 1e546ac..75c270f 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -241,6 +241,7 @@
     srcs = ["batch_resource_base.cc"],
     hdrs = ["batch_resource_base.h"],
     deps = [
+        ":adaptive_shared_batch_scheduler",
         ":batch_scheduler",
         ":concat_split_util",
         ":shared_batch_scheduler",
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
index c0bcbb1..e4af643 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc
@@ -204,7 +204,6 @@
   batcher_queue_options.input_batch_size_limit = max_batch_size;
   batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
   batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
-  // Support for splitting large batch is still in progress.
   batcher_queue_options.enable_large_batch_splitting =
       enable_large_batch_splitting;
   if (enable_large_batch_splitting) {
@@ -227,6 +226,28 @@
   return batcher_queue_options;
 }
 
+/*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions
+BatchResourceBase::GetAdaptiveBatcherQueueOptions(
+    int32 max_batch_size, int32 batch_timeout_micros,
+    int32 max_enqueued_batches, bool enable_large_batch_splitting) {
+  AdaptiveBatcherT::QueueOptions batcher_queue_options;
+  batcher_queue_options.max_batch_size = max_batch_size;
+  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
+  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
+
+  if (enable_large_batch_splitting) {
+    batcher_queue_options.split_input_task_func =
+        [](std::unique_ptr<BatchTask>* input_task,
+           int open_batch_remaining_slot, int max_batch_size,
+           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
+      return SplitInputTask(input_task, open_batch_remaining_slot,
+                            max_batch_size, output_tasks);
+    };
+  }
+
+  return batcher_queue_options;
+}
+
 /*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) {
   for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
     const BatchResourceBase::BatchTask& task = batch.task(task_idx);
diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h
index 6fe11c8..ea8c772 100644
--- a/tensorflow/core/kernels/batching_util/batch_resource_base.h
+++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h
@@ -23,6 +23,7 @@
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
 #include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
@@ -49,7 +50,7 @@
                        const string& batcher_queue_name,
                        AsyncOpKernel::DoneCallback done_callback);
 
- protected:
+ public:
   // One task to be batched, corresponds to a `slice` of input from one batch-op
   // invocation.
   //
@@ -107,6 +108,8 @@
   // tensorflow::serving namespace, because some versions of compiler complain
   // about changing meaning of the symbols.
   using BatcherT = SharedBatchScheduler<BatchResourceBase::BatchTask>;
+  using AdaptiveBatcherT =
+      AdaptiveSharedBatchScheduler<BatchResourceBase::BatchTask>;
   using BatcherQueueT = BatchScheduler<BatchResourceBase::BatchTask>;
   using BatchT = Batch<BatchResourceBase::BatchTask>;
 
@@ -121,11 +124,24 @@
     allowed_batch_sizes_str_ = absl::StrJoin(allowed_batch_sizes_, ",");
   }
 
+  BatchResourceBase(bool has_process_batch_function,
+                    std::shared_ptr<AdaptiveBatcherT> batcher,
+                    const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
+                    std::vector<int32> allowed_batch_sizes)
+      : has_process_batch_function_(has_process_batch_function),
+        adaptive_batcher_(std::move(batcher)),
+        adaptive_batcher_queue_options_(batcher_queue_options),
+        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {}
+
   static BatcherT::QueueOptions GetBatcherQueueOptions(
       int32 num_batch_threads, int32 max_batch_size, int32 batch_timeout_micros,
       int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
       bool enable_large_batch_splitting);
 
+  static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions(
+      int32 max_batch_size, int32 batch_timeout_micros,
+      int32 max_enqueued_batches, bool enable_large_batch_splitting);
+
  private:
   // Implementation of calling the process batch function.
   virtual void ProcessFuncBatchImpl(
@@ -199,6 +215,10 @@
   std::shared_ptr<BatcherT> batcher_;
   BatcherT::QueueOptions batcher_queue_options_;
 
+  // An adaptive batch scheduler, and the options for creating its queues.
+  std::shared_ptr<AdaptiveBatcherT> adaptive_batcher_;
+  AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_;
+
   // A collection of batcher queues, keyed on queue name.
   // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
   // ones (with a time delay?); it's okay if they get recreated later).
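
Note on sharing (illustrative, not part of the patch): because `container_`
and `shared_name_` are set to fixed strings whenever the adaptive scheduler is
enabled, every BatchFunction kernel instance resolves to the same
BatchResource in the ResourceMgr, and therefore to the same
AdaptiveSharedBatchScheduler thread pool. A minimal sketch of that lookup
follows; the wrapper function GetSharedBatchResource is hypothetical, while
ResourceMgr::LookupOrCreate is the existing API called from ComputeAsync.

    // Every adaptive-mode kernel passes the same (container, shared_name)
    // pair, so LookupOrCreate creates the resource once and hands the same
    // instance back to all callers. The caller owns a reference and must
    // Unref() the resource when done with it.
    Status GetSharedBatchResource(OpKernelContext* ctx,
                                  std::function<Status(BatchResource**)> creator,
                                  BatchResource** br) {
      return ctx->resource_manager()->LookupOrCreate<BatchResource>(
          "__adaptive_container", "__adaptive_global_shared_thread_pool", br,
          std::move(creator));
    }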