[tf.data] Improves consistency of IteratorContext lifecycle management. After this change, background threads of asynchronous op kernels own their own copy of the IteratorContext.

PiperOrigin-RevId: 396939365
Change-Id: Ie0ffff34a808036e5d60fbe980a28ac757e0eed1
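
A minimal sketch of the ownership pattern adopted here, using a hypothetical stand-in `Context` class and `std::thread` in place of the real `IteratorContext` and `IteratorContext::StartThread`: the background thread captures a `std::shared_ptr` to its own copy of the context, so the context borrowed by the op kernel can go away as soon as the thread is started.

    #include <iostream>
    #include <memory>
    #include <thread>

    // Hypothetical stand-in for tensorflow::data::IteratorContext,
    // used only to illustrate the ownership pattern.
    struct Context {
      int runner_threadpool_size = 8;
    };

    class Iterator {
     public:
      // Called with a borrowed `ctx` that may be destroyed as soon as we return.
      void EnsureWorkerStarted(Context* ctx) {
        if (worker_.joinable()) return;
        // Copy the context; the shared_ptr keeps that copy alive for the thread.
        auto ctx_copy = std::make_shared<Context>(*ctx);
        worker_ = std::thread([this, ctx_copy]() { WorkerThread(ctx_copy); });
      }

      ~Iterator() {
        if (worker_.joinable()) worker_.join();
      }

     private:
      // The worker holds a share of its context for its entire lifetime.
      void WorkerThread(const std::shared_ptr<Context>& ctx) {
        std::cout << "pool size: " << ctx->runner_threadpool_size << "\n";
      }

      std::thread worker_;
    };

    int main() {
      Iterator it;
      {
        Context ctx;                  // Caller-owned context...
        it.EnsureWorkerStarted(&ctx);
      }                               // ...destroyed here; the worker still owns its copy.
    }
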
diff --git a/tensorflow/core/data/captured_function.cc b/tensorflow/core/data/captured_function.cc
index 8272f5d..1d60519 100644
--- a/tensorflow/core/data/captured_function.cc
+++ b/tensorflow/core/data/captured_function.cc
@@ -389,11 +389,11 @@
     const std::vector<Tensor>& input_element, int64_t thread_index,
     const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
     std::unique_ptr<IteratorBase>* out_iterator,
-    std::shared_ptr<model::Node> node) {
+    const std::shared_ptr<model::Node>& node) {
   std::vector<Tensor> return_values;
 
   TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
-      ctx, input_element, &return_values, std::move(node)));
+      ctx, input_element, &return_values, node));
 
   if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
         TensorShapeUtils::IsScalar(return_values[0].shape()))) {
@@ -773,7 +773,7 @@
 
 Status InstantiatedCapturedFunction::Run(
     IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
-    std::shared_ptr<model::Node> node) const {
+    const std::shared_ptr<model::Node>& node) const {
   auto& info = captured_func_->short_circuit_info();
   if (!info.indices.empty()) {
     return RunShortCircuit(info, std::move(args), captured_func_, rets);
@@ -836,7 +836,7 @@
 
 Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
     IteratorContext* ctx, const std::vector<Tensor>& args,
-    std::vector<Tensor>* rets, std::shared_ptr<model::Node> node) const {
+    std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
   auto& info = captured_func_->short_circuit_info();
   if (!info.indices.empty()) {
     return RunShortCircuit(info, args, captured_func_, rets);
@@ -921,13 +921,10 @@
   return frame.ConsumeRetvals(rets);
 }
 
-// NOTE: The `done` callback will be invoked asynchronously from the calling
-// thread. The caller is therefore responsible for making sure that any objects
-// accessed by the callback exist at least until the callback returns.
 void InstantiatedCapturedFunction::RunAsync(
     IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
     FunctionLibraryRuntime::DoneCallback done,
-    std::shared_ptr<model::Node> node) const {
+    const std::shared_ptr<model::Node>& node) const {
   auto& info = captured_func_->short_circuit_info();
   if (!info.indices.empty()) {
     // Run the `done` callback on a threadpool thread, because it will
@@ -940,6 +937,9 @@
     return;
   }
 
+  // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
+  // be deleted before `done` is called. Take care not to capture `ctx` in any
+  // code that may execute asynchronously in this function.
   OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
       std::move(args), &captured_func_->captured_inputs(), ret_types_);
 
@@ -953,7 +953,7 @@
   f_opts.runner = ctx->runner();
   f_opts.create_rendezvous = ShouldCreateRendezvous();
   auto cancellation_manager =
-      std::make_shared<CancellationManager>(ctx->cancellation_manager());
+      absl::make_unique<CancellationManager>(ctx->cancellation_manager());
   f_opts.cancellation_manager = cancellation_manager.get();
   f_opts.collective_executor = ctx->collective_executor();
 
@@ -964,13 +964,19 @@
   const bool collect_usage = node && ctx->model();
   f_opts.stats_collector = stats_collector.get();
 
-  // Transferring ownership of `step_container` and `frame` into `callback`.
-  auto callback =
-      [this, stats_collector = std::move(stats_collector),
-       stats_aggregator = ctx->stats_aggregator(), done = std::move(done),
-       cancellation_manager = std::move(cancellation_manager), node, rets,
-       step_container, frame, collect_usage](Status s) {
+  // Transfer ownership of the cancellation manager to `callback`.
+  CancellationManager* raw_cancellation_manager =
+      cancellation_manager.release();
+  auto callback = std::bind(
+      [this, rets, step_container, raw_cancellation_manager, frame, node,
+       collect_usage](
+          const FunctionLibraryRuntime::DoneCallback& done,
+          IteratorContext* ctx,
+          const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
+          // Begin unbound arguments.
+          Status s) {
         delete step_container;
+        delete raw_cancellation_manager;
         if (s.ok()) {
           s = frame->ConsumeRetvals(rets);
         }
@@ -979,11 +985,11 @@
           // TODO(b/129085499) Utilize the `node_name` which would be unique
           // than the prefix for the function execution time statistics.
           // prefix_with_func_name would then be node_name + func_name.
-          if (stats_aggregator) {
+          if (ctx->stats_aggregator()) {
             string prefix_with_func_name =
                 strings::StrCat(node->name(), stats_utils::kDelimiter,
                                 captured_func_->func().name());
-            stats_aggregator->AddToHistogram(
+            ctx->stats_aggregator()->AddToHistogram(
                 stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
                 {static_cast<float>(stats_collector->processing_time())},
                 node->num_elements());
@@ -997,7 +1003,8 @@
         if (collect_usage) {
           node->record_stop(EnvTime::NowNanos());
         }
-      };
+      },
+      std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);
 
   profiler::TraceMe activity(
       [&] {
diff --git a/tensorflow/core/data/captured_function.h b/tensorflow/core/data/captured_function.h
index afa197f..82d2b1a 100644
--- a/tensorflow/core/data/captured_function.h
+++ b/tensorflow/core/data/captured_function.h
@@ -56,7 +56,7 @@
     const std::vector<Tensor>& input_element, int64_t thread_index,
     const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
     std::unique_ptr<IteratorBase>* out_iterator,
-    std::shared_ptr<model::Node> node);
+    const std::shared_ptr<model::Node>& node);
 
 struct ShortCircuitInfo {
   std::vector<int> indices;
@@ -246,7 +246,7 @@
   // called `DatasetBaseIterator::RecordStart().
   Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
              std::vector<Tensor>* rets,
-             std::shared_ptr<model::Node> node) const;
+             const std::shared_ptr<model::Node>& node) const;
 
   // Synchronously runs the captured function on the given `args`, and stores
   // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
@@ -264,7 +264,7 @@
   Status RunWithBorrowedArgs(IteratorContext* ctx,
                              const std::vector<Tensor>& args,
                              std::vector<Tensor>* rets,
-                             std::shared_ptr<model::Node> node) const;
+                             const std::shared_ptr<model::Node>& node) const;
 
   // Synchronously runs the captured function on the given `args`, and stores
   // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
@@ -285,7 +285,7 @@
   void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
                 std::vector<Tensor>* rets,
                 FunctionLibraryRuntime::DoneCallback done,
-                std::shared_ptr<model::Node> node) const;
+                const std::shared_ptr<model::Node>& node) const;
 
  private:
   friend class CapturedFunction;
diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
index c323b0a..fe71f21 100644
--- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
@@ -77,9 +77,6 @@
 // Computes ceil(x / y).
 inline int64_t CeilDiv(int64_t x, int64_t y) { return (x + y - 1) / y; }
 
-// Period between reporting dataset statistics.
-constexpr int kStatsReportingPeriodMillis = 1000;
-
 }  // namespace
 
 class MapAndBatchDatasetOp::Dataset : public DatasetBase {
@@ -381,14 +378,25 @@
       const uint64 uid = -1;
     };
 
-    void CallCompleted(BatchResult* result) TF_LOCKS_EXCLUDED(*mu_) {
+    void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
+                       const std::shared_ptr<BatchResult>& result)
+        TF_LOCKS_EXCLUDED(*mu_) {
       mutex_lock l(*mu_);
       num_calls_--;
       result->num_calls--;
+      const auto& stats_aggregator = ctx->stats_aggregator();
+      if (stats_aggregator) {
+        stats_aggregator->AddScalar(
+            stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
+            static_cast<float>(num_calls_) /
+                static_cast<float>(num_parallel_calls_->value),
+            num_elements());
+      }
       cond_var_->notify_all();
     }
 
-    void CallFunction(std::shared_ptr<IteratorContext> ctx, BatchResult* result,
+    void CallFunction(std::shared_ptr<IteratorContext> ctx,
+                      const std::shared_ptr<BatchResult>& result,
                       int64_t offset) TF_LOCKS_EXCLUDED(*mu_) {
       profiler::TraceMe traceme([&] {
         return profiler::TraceMeEncode("MapAndBatchProduce",
@@ -407,7 +415,7 @@
         return_early = result->end_of_input || !result->status.ok();
       }
       if (return_early) {
-        CallCompleted(result);
+        CallCompleted(ctx, result);
         return;
       }
 
@@ -425,7 +433,7 @@
         result->UpdateStatus(status, offset);
         if (status.ok()) {
           Status allocate_status =
-              EnsureOutputAllocated(ctx.get(), *return_values, result);
+              EnsureOutputAllocated(ctx, result, return_values);
           if (!allocate_status.ok()) {
             result->UpdateStatus(allocate_status, offset);
           } else {
@@ -461,7 +469,7 @@
             result->num_elements++;
           }
         }
-        CallCompleted(result);
+        CallCompleted(ctx, result);
       };
 
       // Apply the map function on `input_element`, storing the result in
@@ -485,53 +493,47 @@
     void EnsureRunnerThreadStarted(IteratorContext* ctx)
         TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
+        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
         runner_thread_ = ctx->StartThread(
-            "tf_data_map_and_batch",
-            [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
-              RunnerThread(ctx);
-            });
-        if (ctx->stats_aggregator()) {
-          stats_thread_ = ctx->StartThread(
-              "tf_data_map_and_batch_stats",
-              [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
-                StatsThread(ctx.get());
-              });
-        }
+            kTFDataMapAndBatch,
+            std::bind(&Iterator::RunnerThread, this, ctx_copy));
       }
     }
 
-    Status EnsureOutputAllocated(IteratorContext* ctx,
-                                 const std::vector<Tensor>& return_values,
-                                 BatchResult* result) {
+    Status EnsureOutputAllocated(
+        const std::shared_ptr<IteratorContext>& ctx,
+        const std::shared_ptr<BatchResult>& result,
+        const std::shared_ptr<std::vector<Tensor>>& return_values) {
       mutex_lock l(result->mu);
       if (result->output_allocated) {
         return Status::OK();
       }
-      const size_t num_components = return_values.size();
+      const size_t num_components = return_values->size();
       result->output.reserve(num_components);
       for (size_t i = 0; i < num_components; ++i) {
         TensorShape component_shape({dataset()->batch_size_});
-        component_shape.AppendShape(return_values.at(i).shape());
+        component_shape.AppendShape(return_values->at(i).shape());
         AllocatorAttributes attr;
         attr.set_gpu_compatible(true);
-        result->output.emplace_back(
-            ctx->allocator(attr), return_values.at(i).dtype(), component_shape);
+        result->output.emplace_back(ctx->allocator(attr),
+                                    return_values->at(i).dtype(),
+                                    component_shape);
         if (!result->output.back().IsInitialized()) {
           return errors::ResourceExhausted(
               "Failed to allocate memory for the batch of component ", i);
         }
       }
-      RecordBufferEnqueue(ctx, result->output);
+      RecordBufferEnqueue(ctx.get(), result->output);
       result->output_allocated = true;
       return Status::OK();
     }
 
-    void RunnerThread(std::shared_ptr<IteratorContext> ctx)
+    void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
         TF_LOCKS_EXCLUDED(*mu_) {
       std::vector<std::pair<std::shared_ptr<BatchResult>, int64_t>> new_calls;
       RecordStart(ctx.get());
       auto stop_cleanup =
-          gtl::MakeCleanup([this, ctx]() { RecordStop(ctx.get()); });
+          gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
       {
         tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
         new_calls.reserve(num_parallel_calls_->value);
@@ -575,41 +577,22 @@
             num_calls_++;
           }
         }
+        const auto& stats_aggregator = ctx->stats_aggregator();
+        if (stats_aggregator) {
+          mutex_lock l(*mu_);
+          stats_aggregator->AddScalar(
+              stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
+              static_cast<float>(num_calls_) /
+                  static_cast<float>(num_parallel_calls_->value),
+              num_elements());
+        }
         for (const auto& call : new_calls) {
-          CallFunction(ctx, call.first.get(), call.second);
+          CallFunction(ctx, call.first, call.second);
         }
         new_calls.clear();
       }
     }
 
-    void StatsThread(IteratorContext* ctx) {
-      for (int64_t step = 0;; ++step) {
-        int num_calls;
-        int num_parallel_calls;
-        {
-          mutex_lock l(*mu_);
-          if (step != 0 && !cancelled_) {
-            cond_var_->wait_for(
-                l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
-          }
-          if (cancelled_) {
-            return;
-          }
-          num_calls = num_calls_;
-          num_parallel_calls = num_parallel_calls_->value;
-        }
-        if (num_parallel_calls == 0) {
-          // Avoid division by zero.
-          num_parallel_calls = 1;
-        }
-        ctx->stats_aggregator()->AddScalar(
-            stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
-            static_cast<float>(num_calls) /
-                static_cast<float>(num_parallel_calls),
-            step);
-      }
-    }
-
     Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
                            size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       batch_results_.push_back(
@@ -696,8 +679,6 @@
     std::deque<std::shared_ptr<BatchResult>> batch_results_ TF_GUARDED_BY(*mu_);
     // Background thread used for coordinating input processing.
     std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
-    // Background thread used for collecting statistics.
-    std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
     // Determines whether the transformation has been cancelled.
     bool cancelled_ TF_GUARDED_BY(*mu_) = false;
     // Identifies the number of callers currently waiting for a batch result.
diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
index 665e9b0..68feac3 100644
--- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
@@ -564,11 +564,10 @@
       if (reader->Contains(prefix(), kWorkerThreadsRunning)) {
         worker_threads_.reserve(dataset()->num_threads());
         for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
           worker_threads_.emplace_back(ctx->StartThread(
               strings::StrCat(kDataParallelInterleaveWorker, "_", i),
-              [this, ctx = std::make_shared<IteratorContext>(*ctx), i]() {
-                WorkerThread(ctx.get(), i);
-              }));
+              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
         }
       }
       return Status::OK();
@@ -679,11 +678,10 @@
             return Status::OK();
           }
           workers_[i].SetInputs(s, std::move(args));
+          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
           worker_threads_.push_back(ctx->StartThread(
               strings::StrCat(kDataParallelInterleaveWorker, "_", i),
-              [this, ctx = std::make_shared<IteratorContext>(*ctx), i]() {
-                WorkerThread(ctx.get(), i);
-              }));
+              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
           if (i < dataset()->cycle_length_) {
             interleave_indices_.push_back(i);
           } else {
@@ -697,7 +695,8 @@
     }
 
     // Produces elements into the worker's output buffers.
-    void WorkerThread(IteratorContext* ctx, int64_t thread_index) {
+    void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
+                      const int64_t thread_index) {
       // Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
       //
       // 1. Any local state that may need to be checkpointed should be kept
@@ -718,11 +717,11 @@
 
       // std::function arguments are copy-constructable, so we pass raw
       // pointers, and then immediately wrap them to ensure correct ownership.
-      RecordStart(ctx);
+      RecordStart(ctx.get());
       auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
         mutex_lock l(mu_);
         workers_[thread_index].cond_var.notify_all();
-        RecordStop(ctx);
+        RecordStop(ctx.get());
       });
       bool make_new_iterator;
       {
@@ -759,9 +758,9 @@
           if (read_new_input) {
             mutex_lock l(mu_);
             while (!cancelled_ && !workers_[thread_index].is_producing) {
-              RecordStop(ctx);
+              RecordStop(ctx.get());
               workers_[thread_index].cond_var.wait(l);
-              RecordStart(ctx);
+              RecordStart(ctx.get());
             }
             if (cancelled_) return;
             // Copy the input tensors so that we do not need to block on `mu_`
@@ -783,7 +782,7 @@
             tf_shared_lock l(ckpt_mu_);
             worker_thread_states_[thread_index].iterator_creation_status =
                 MakeIteratorFromInputElement(
-                    ctx, this, worker_thread_states_[thread_index].input,
+                    ctx.get(), this, worker_thread_states_[thread_index].input,
                     thread_index, *instantiated_captured_func_, prefix(),
                     &worker_thread_states_[thread_index].iterator,
                     model_node());
@@ -812,9 +811,9 @@
           // Wait for space in the prefetch queue.
           while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                     dataset()->buffer_output_elements_) {
-            RecordStop(ctx);
+            RecordStop(ctx.get());
             workers_[thread_index].cond_var.wait(l);
-            RecordStart(ctx);
+            RecordStart(ctx.get());
           }
           if (cancelled_) return;
           tf_shared_lock ckpt_l(ckpt_mu_);
@@ -851,7 +850,7 @@
                     profiler::kInfo);
                 worker_thread_states_[thread_index].output_elem.status =
                     worker_thread_states_[thread_index].iterator->GetNext(
-                        ctx,
+                        ctx.get(),
                         &worker_thread_states_[thread_index].output_elem.output,
                         &worker_thread_states_[thread_index].end_of_sequence);
                 end_of_sequence =
@@ -873,9 +872,9 @@
               // Wait for space in the prefetch queue.
               while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                         dataset()->buffer_output_elements_) {
-                RecordStop(ctx);
+                RecordStop(ctx.get());
                 workers_[thread_index].cond_var.wait(l);
-                RecordStart(ctx);
+                RecordStart(ctx.get());
               }
               if (cancelled_) return;
 
diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
index a894565..897f1c4 100644
--- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
@@ -412,7 +412,7 @@
           return profiler::TraceMeEncode("ParseExampleConsume",
                                          {{"element_id", result->id}});
         });
-        return ProcessResult(ctx, result.get(), out_tensors, end_of_sequence);
+        return ProcessResult(ctx, result, out_tensors, end_of_sequence);
       }
 
      protected:
@@ -554,32 +554,31 @@
       void EnsureThreadsStarted(IteratorContext* ctx)
           TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
         if (!runner_thread_) {
+          auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
           runner_thread_ = ctx->StartThread(
-              "tf_data_parse_example",
-              [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
-                RunnerThread(ctx);
-              });
+              "tf_data_parallel_map",
+              std::bind(&Iterator::RunnerThread, this, ctx_copy));
           if (ctx->stats_aggregator()) {
             stats_thread_ = ctx->StartThread(
-                "tf_data_parse_example_stats",
-                [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
-                  StatsThread(ctx.get());
-                });
+                "tf_data_parallel_map_stats",
+                std::bind(&Iterator::StatsThread, this, ctx_copy));
           }
         }
       }
 
-      void CallCompleted(IteratorContext* ctx, InvocationResult* result)
+      void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
+                         const std::shared_ptr<InvocationResult>& result)
           TF_LOCKS_EXCLUDED(*mu_) {
         mutex_lock l(*mu_);
         num_calls_--;
-        RecordBufferEnqueue(ctx, result->return_values);
+        RecordBufferEnqueue(ctx.get(), result->return_values);
         result->notification.Notify();
         cond_var_->notify_all();
       }
 
-      void CallFunction(std::shared_ptr<IteratorContext> ctx,
-                        InvocationResult* result) TF_LOCKS_EXCLUDED(*mu_) {
+      void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
+                        const std::shared_ptr<InvocationResult>& result)
+          TF_LOCKS_EXCLUDED(*mu_) {
         profiler::TraceMe traceme([&] {
           return profiler::TraceMeEncode("ParseExampleProduce",
                                          {{"element_id", result->id}});
@@ -589,13 +588,13 @@
         result->status = input_impl_->GetNext(ctx.get(), &input_element,
                                               &result->end_of_input);
         if (result->end_of_input || !result->status.ok()) {
-          CallCompleted(ctx.get(), result);
+          CallCompleted(ctx, result);
           return;
         }
 
         auto done = [this, ctx, result](Status status) {
           result->status.Update(status);
-          CallCompleted(ctx.get(), result);
+          CallCompleted(ctx, result);
         };
 
         // We schedule the `ParseExample` function using `ctx->runner()` to
@@ -724,7 +723,8 @@
         return Status::OK();
       }
 
-      Status ProcessResult(IteratorContext* ctx, InvocationResult* result,
+      Status ProcessResult(IteratorContext* ctx,
+                           const std::shared_ptr<InvocationResult>& result,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) {
         if (!result->end_of_input && result->status.ok()) {
@@ -745,7 +745,7 @@
         return result->status;
       }
 
-      void RunnerThread(std::shared_ptr<IteratorContext> ctx)
+      void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
           TF_LOCKS_EXCLUDED(*mu_) {
         RecordStart(ctx.get());
         auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
@@ -781,7 +781,7 @@
             cond_var_->notify_all();
           }
           for (const auto& call : new_calls) {
-            CallFunction(ctx, call.get());
+            CallFunction(ctx, call);
           }
           new_calls.clear();
         }
@@ -819,7 +819,7 @@
         return true;
       }
 
-      void StatsThread(IteratorContext* ctx) {
+      void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
         for (int64_t step = 0;; ++step) {
           int num_calls;
           int num_parallel_calls;
diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc
index b0d7ab8..46be263 100644
--- a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc
@@ -337,7 +337,9 @@
       const int64_t uid = -1;
     };
 
-    void CallCompleted(BatchResult* result) TF_LOCKS_EXCLUDED(*mu_) {
+    void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
+                       const std::shared_ptr<BatchResult>& result)
+        TF_LOCKS_EXCLUDED(*mu_) {
       mutex_lock l(*mu_);
       num_calls_--;
       result->call_finished = true;
@@ -347,7 +349,8 @@
     // The function fetches elements from input dataset sequentially and then
     // executes the batching for different batches in parallel using the context
     // runner.
-    void CallBatching(std::shared_ptr<IteratorContext> ctx, BatchResult* result)
+    void CallBatching(std::shared_ptr<IteratorContext> ctx,
+                      const std::shared_ptr<BatchResult>& result)
         TF_LOCKS_EXCLUDED(*mu_) {
       profiler::TraceMe traceme([&] {
         return profiler::TraceMeEncode("ParallelBatchProduce",
@@ -355,7 +358,7 @@
       });
 
       if (!input_impl_) {
-        CallCompleted(result);
+        CallCompleted(ctx, result);
         return;
       }
 
@@ -386,7 +389,7 @@
       }
 
       if (batch_elements->empty()) {
-        CallCompleted(result);
+        CallCompleted(ctx, result);
         return;
       }
 
@@ -406,7 +409,7 @@
                              std::move(allocation_callback), &result->output);
           result->status.Update(status);
         }
-        CallCompleted(result);
+        CallCompleted(ctx, result);
         return status;
       };
 
@@ -427,15 +430,14 @@
     void EnsureRunnerThreadStarted(IteratorContext* ctx)
         TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
+        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
         runner_thread_ = ctx->StartThread(
-            "tf_data_parallel_batch",
-            [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
-              RunnerThread(ctx);
-            });
+            kTFDataParallelBatch,
+            std::bind(&Iterator::RunnerThread, this, ctx_copy));
       }
     }
 
-    void RunnerThread(std::shared_ptr<IteratorContext> ctx)
+    void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
         TF_LOCKS_EXCLUDED(*mu_) {
       std::vector<std::shared_ptr<BatchResult>> new_calls;
       RecordStart(ctx.get());
@@ -470,7 +472,7 @@
           }
         }
         for (const auto& call : new_calls) {
-          CallBatching(ctx, call.get());
+          CallBatching(ctx, call);
         }
         new_calls.clear();
       }
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index b67491b..7a5260a 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -257,7 +257,7 @@
         return profiler::TraceMeEncode("ParallelMapConsume",
                                        {{"element_id", result->uid}});
       });
-      return ProcessResult(ctx, result.get(), out_tensors, end_of_sequence);
+      return ProcessResult(ctx, result, out_tensors, end_of_sequence);
     }
 
    protected:
@@ -395,30 +395,30 @@
     void EnsureThreadsStarted(IteratorContext* ctx)
         TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!runner_thread_) {
+        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
         runner_thread_ = ctx->StartThread(
             "tf_data_parallel_map",
-            [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
-              RunnerThread(ctx);
-            });
+            std::bind(&Iterator::RunnerThread, this, ctx_copy));
         if (ctx->stats_aggregator()) {
           stats_thread_ = ctx->StartThread(
               "tf_data_parallel_map_stats",
-              [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
-                StatsThread(ctx.get());
-              });
+              std::bind(&Iterator::StatsThread, this, ctx_copy));
         }
       }
     }
 
-    void CallCompleted(InvocationResult* result) TF_LOCKS_EXCLUDED(*mu_) {
+    void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
+                       const std::shared_ptr<InvocationResult>& result)
+        TF_LOCKS_EXCLUDED(*mu_) {
       mutex_lock l(*mu_);
       num_calls_--;
       result->notification.Notify();
       cond_var_->notify_all();
     }
 
-    void CallFunction(std::shared_ptr<IteratorContext> ctx,
-                      InvocationResult* result) TF_LOCKS_EXCLUDED(*mu_) {
+    void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
+                      const std::shared_ptr<InvocationResult>& result)
+        TF_LOCKS_EXCLUDED(*mu_) {
       profiler::TraceMe traceme([&] {
         return profiler::TraceMeEncode("ParallelMapProduce",
                                        {{"element_id", result->uid}});
@@ -428,14 +428,14 @@
       result->status = input_impl_->GetNext(ctx.get(), &input_element,
                                             &result->end_of_input);
       if (result->end_of_input || !result->status.ok()) {
-        CallCompleted(result);
+        CallCompleted(ctx, result);
         return;
       }
 
       auto done = [this, ctx, result](Status status) {
         result->status.Update(status);
         RecordBufferEnqueue(ctx.get(), result->return_values);
-        CallCompleted(result);
+        CallCompleted(ctx, result);
       };
 
       // Apply the map function on `input_element`, storing the result in
@@ -472,7 +472,8 @@
       }
     }
 
-    Status ProcessResult(IteratorContext* ctx, InvocationResult* result,
+    Status ProcessResult(IteratorContext* ctx,
+                         const std::shared_ptr<InvocationResult>& result,
                          std::vector<Tensor>* out_tensors,
                          bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) {
       if (!result->end_of_input && result->status.ok()) {
@@ -500,7 +501,7 @@
       return result->status;
     }
 
-    void RunnerThread(std::shared_ptr<IteratorContext> ctx)
+    void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
         TF_LOCKS_EXCLUDED(*mu_) {
       RecordStart(ctx.get());
       auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
@@ -533,13 +534,13 @@
           cond_var_->notify_all();
         }
         for (const auto& call : new_calls) {
-          CallFunction(ctx, call.get());
+          CallFunction(ctx, call);
         }
         new_calls.clear();
       }
     }
 
-    // Determines whether the caller needs to wait for a result-> Upon returning
+    // Determines whether the caller needs to wait for a result. Upon returning
     // false, `result` will point to the result.
     bool ShouldWait(std::shared_ptr<InvocationResult>* result)
         TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
@@ -571,7 +572,7 @@
       return true;
     }
 
-    void StatsThread(IteratorContext* ctx) {
+    void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
       for (int64_t step = 0;; ++step) {
         int num_calls;
         int num_parallel_calls;
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 0be61ef..ba15c60 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -438,11 +438,10 @@
     Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
         TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
       if (!prefetch_thread_) {
+        std::shared_ptr<IteratorContext> new_ctx =
+            std::make_shared<IteratorContext>(*ctx);
         prefetch_thread_ = ctx->StartThread(
-            "tf_data_prefetch",
-            [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
-              PrefetchThread(ctx.get());
-            });
+            "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
       }
       return Status::OK();
     }
@@ -450,9 +449,9 @@
     // Prefetches elements of the input, storing results in an internal buffer.
     //
     // It owns the iterator context passed to it.
-    void PrefetchThread(IteratorContext* ctx) {
-      RecordStart(ctx);
-      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx); });
+    void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
+      RecordStart(ctx.get());
+      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
       // Keep track of where we are in an iteration "burst"
       int num_produced = 0;
       while (true) {
@@ -460,9 +459,9 @@
         {
           mutex_lock l(*mu_);
           while (!cancelled_ && buffer_.size() >= buffer_limit()) {
-            RecordStop(ctx);
+            RecordStop(ctx.get());
             cond_var_->wait(l);
-            RecordStart(ctx);
+            RecordStart(ctx.get());
           }
 
           if (cancelled_) {
@@ -496,7 +495,7 @@
               },
               profiler::kInfo);
           buffer_element.status = input_impl_->GetNext(
-              ctx, &buffer_element.value, &end_of_sequence);
+              ctx.get(), &buffer_element.value, &end_of_sequence);
         }
         if (buffer_element.status.ok() && end_of_sequence) {
           mutex_lock l(*mu_);
@@ -508,7 +507,7 @@
         // 3. Signal that the element has been produced.
         {
           mutex_lock l(*mu_);
-          RecordBufferEnqueue(ctx, buffer_element.value);
+          RecordBufferEnqueue(ctx.get(), buffer_element.value);
           buffer_element.created_us = EnvTime::NowMicros();
           buffer_.push_back(std::move(buffer_element));
           cond_var_->notify_all();