[tf.data] Fix OOM when tf.data map_and_batch is used with num_parallel_calls = autotune, batch_size = 1.
Closes #33516.
PiperOrigin-RevId: 281775472
Change-Id: Ie10cea0ef1515d5aff8e3dddadc069ddee1a5a76
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index d845d93..f4ad23a 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -184,6 +184,7 @@
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
index f765cff..6fbf153 100644
--- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
@@ -20,6 +20,7 @@
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
+#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
@@ -170,9 +171,12 @@
num_parallel_calls_(std::make_shared<model::SharedState>(
params.dataset->num_parallel_calls_, mu_, cond_var_)),
max_batch_results_(
- std::min(kMaxBatchResults, (params.dataset->num_parallel_calls_ +
- params.dataset->batch_size_ - 1) /
- params.dataset->batch_size_)) {}
+ params.dataset->num_parallel_calls_ == model::kAutotune
+ ? kMaxBatchResults
+ : std::min(kMaxBatchResults,
+ (params.dataset->num_parallel_calls_ +
+ params.dataset->batch_size_ - 1) /
+ params.dataset->batch_size_)) {}
~Iterator() override {
mutex_lock l(*mu_);