[tf.data] Implement `interleave_depth` value for different iterators, which counts the number of ParallelInterleaveDatasets in the path from root node to the node in the input pipeline tree, not including the node itself.

PiperOrigin-RevId: 394826250
Change-Id: I2aa7d2a56db10e9a1e78573d5318be0d36e1e181
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 06fdc59..8a59cb7 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -395,7 +395,8 @@
           split_providers(ctx->split_providers()),
           stats_aggregator(ctx->stats_aggregator()),
           thread_factory(ctx->thread_factory()),
-          thread_pool(ctx->thread_pool()) {}
+          thread_pool(ctx->thread_pool()),
+          interleave_depth(ctx->interleave_depth()) {}
 
     explicit Params(OpKernelContext* ctx)
         : collective_executor(ctx->collective_executor()),
@@ -475,6 +476,11 @@
 
     // A shared thread pool to schedule computation into.
     thread::ThreadPoolInterface* thread_pool = nullptr;
+
+    // Records the number of ParallelInterleave operations in the path from the
+    // root node to this node (not including this node) in the input pipeline
+    // tree.
+    int64 interleave_depth = 0;
   };
 
   explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
@@ -533,6 +539,8 @@
 
   thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
 
+  int64 interleave_depth() { return params_.interleave_depth; }
+
   std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
                                                        int num_threads) {
     if (params_.thread_pool) {
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 1eaede7..64e8647 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -346,6 +346,7 @@
       ctx_ = std::make_unique<IteratorContext>(*ctx);
       cancellation_manager_ = absl::make_unique<CancellationManager>();
       IteratorContext::Params params(ctx);
+      params.interleave_depth += 1;
       params.cancellation_manager = cancellation_manager_.get();
       TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
           IteratorContext(params), this, prefix(), &input_impl_));
@@ -1029,8 +1030,11 @@
         }
         element->inputs =
             absl::make_unique<std::vector<Tensor>>(std::move(inputs));
+        IteratorContext::Params params(ctx_.get());
+        params.interleave_depth += 1;
+        IteratorContext ctx(params);
         status = MakeIteratorFromInputElement(
-            ctx_.get(), this, *element->inputs, element->id,
+            &ctx, this, *element->inputs, element->id,
             *instantiated_captured_func_, prefix(), &element->iterator,
             model_node());
         if (!status.ok()) {
@@ -1324,8 +1328,11 @@
         }
         TF_RETURN_IF_ERROR(
             reader->ReadScalar(iterator_name, kIdSuffix, &element->id));
+        IteratorContext::Params params(ctx);
+        params.interleave_depth += 1;
+        IteratorContext ctx_copy(params);
         TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
-            ctx, this, *element->inputs, element->id,
+            &ctx_copy, this, *element->inputs, element->id,
             *instantiated_captured_func_.get(), prefix(), &iterator,
             model_node()));
       }