[tf.data] Implement the `interleave_depth` value for different iterators, which counts the number of ParallelInterleaveDatasets on the path from the root node to a given node in the input pipeline tree, not including the node itself.
PiperOrigin-RevId: 394826250
Change-Id: I2aa7d2a56db10e9a1e78573d5318be0d36e1e181
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 06fdc59..8a59cb7 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -395,7 +395,8 @@
split_providers(ctx->split_providers()),
stats_aggregator(ctx->stats_aggregator()),
thread_factory(ctx->thread_factory()),
- thread_pool(ctx->thread_pool()) {}
+ thread_pool(ctx->thread_pool()),
+ interleave_depth(ctx->interleave_depth()) {}
explicit Params(OpKernelContext* ctx)
: collective_executor(ctx->collective_executor()),
@@ -475,6 +476,11 @@
// A shared thread pool to schedule computation into.
thread::ThreadPoolInterface* thread_pool = nullptr;
+
+ // Records the number of ParallelInterleave operations in the path from the
+ // root node to this node (not including this node) in the input pipeline
+ // tree.
+ int64 interleave_depth = 0;
};
explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
@@ -533,6 +539,8 @@
thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
+ int64 interleave_depth() { return params_.interleave_depth; }
+
std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
int num_threads) {
if (params_.thread_pool) {
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 1eaede7..64e8647 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -346,6 +346,7 @@
ctx_ = std::make_unique<IteratorContext>(*ctx);
cancellation_manager_ = absl::make_unique<CancellationManager>();
IteratorContext::Params params(ctx);
+ params.interleave_depth += 1;
params.cancellation_manager = cancellation_manager_.get();
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
IteratorContext(params), this, prefix(), &input_impl_));
@@ -1029,8 +1030,11 @@
}
element->inputs =
absl::make_unique<std::vector<Tensor>>(std::move(inputs));
+ IteratorContext::Params params(ctx_.get());
+ params.interleave_depth += 1;
+ IteratorContext ctx(params);
status = MakeIteratorFromInputElement(
- ctx_.get(), this, *element->inputs, element->id,
+ &ctx, this, *element->inputs, element->id,
*instantiated_captured_func_, prefix(), &element->iterator,
model_node());
if (!status.ok()) {
@@ -1324,8 +1328,11 @@
}
TF_RETURN_IF_ERROR(
reader->ReadScalar(iterator_name, kIdSuffix, &element->id));
+ IteratorContext::Params params(ctx);
+ params.interleave_depth += 1;
+ IteratorContext ctx_copy(params);
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
- ctx, this, *element->inputs, element->id,
+ &ctx_copy, this, *element->inputs, element->id,
*instantiated_captured_func_.get(), prefix(), &iterator,
model_node()));
}