[tf.data] Improves consistency of `IteratorContext` lifecycle management. After this change, the background threads of asynchronous op kernels own a copy of the `IteratorContext`, rather than holding a raw pointer to a context owned elsewhere.
PiperOrigin-RevId: 396939365
Change-Id: Ie0ffff34a808036e5d60fbe980a28ac757e0eed1
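
For illustration of the idiom adopted below: each kernel now copies the caller's context into a std::shared_ptr and binds that shared_ptr directly to the background thread's entry point, so the thread's context stays alive for as long as the thread runs, independent of the caller. The following is a minimal standalone sketch of that pattern only; `Context` and the free `StartThread` helper are simplified stand-ins, not the real tf.data API (the actual code uses `IteratorContext`, `IteratorContext::StartThread`, and member functions such as `Iterator::RunnerThread`, as shown in the diff).

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>

// Illustrative stand-in for tensorflow::data::IteratorContext.
struct Context {
  int value;
};

// Illustrative stand-in for IteratorContext::StartThread().
std::thread StartThread(const std::string& /*name*/, std::function<void()> fn) {
  return std::thread(std::move(fn));
}

// Mirrors the new thread-function signature: it takes a shared_ptr, so the
// thread shares ownership of its own copy of the context.
void RunnerThread(const std::shared_ptr<Context>& ctx) {
  std::cout << "runner sees value " << ctx->value << "\n";
}

int main() {
  std::thread runner;
  {
    Context caller_ctx{42};  // Owned by the caller; destroyed at end of scope.
    // Copy the caller's context into a shared_ptr and bind it to the thread,
    // analogous to what EnsureRunnerThreadStarted() now does.
    auto ctx_copy = std::make_shared<Context>(caller_ctx);
    runner = StartThread("tf_data_example",
                         std::bind(&RunnerThread, ctx_copy));
  }  // caller_ctx no longer exists; the thread still holds its own copy.
  runner.join();
  return 0;
}
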
diff --git a/tensorflow/core/data/captured_function.cc b/tensorflow/core/data/captured_function.cc
index 8272f5d..1d60519 100644
--- a/tensorflow/core/data/captured_function.cc
+++ b/tensorflow/core/data/captured_function.cc
@@ -389,11 +389,11 @@
const std::vector<Tensor>& input_element, int64_t thread_index,
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator,
- std::shared_ptr<model::Node> node) {
+ const std::shared_ptr<model::Node>& node) {
std::vector<Tensor> return_values;
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
- ctx, input_element, &return_values, std::move(node)));
+ ctx, input_element, &return_values, node));
if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
TensorShapeUtils::IsScalar(return_values[0].shape()))) {
@@ -773,7 +773,7 @@
Status InstantiatedCapturedFunction::Run(
IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
- std::shared_ptr<model::Node> node) const {
+ const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
return RunShortCircuit(info, std::move(args), captured_func_, rets);
@@ -836,7 +836,7 @@
Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
IteratorContext* ctx, const std::vector<Tensor>& args,
- std::vector<Tensor>* rets, std::shared_ptr<model::Node> node) const {
+ std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
return RunShortCircuit(info, args, captured_func_, rets);
@@ -921,13 +921,10 @@
return frame.ConsumeRetvals(rets);
}
-// NOTE: The `done` callback will be invoked asynchronously from the calling
-// thread. The caller is therefore responsible for making sure that any objects
-// accessed by the callback exist at least until the callback returns.
void InstantiatedCapturedFunction::RunAsync(
IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done,
- std::shared_ptr<model::Node> node) const {
+ const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
// Run the `done` callback on a threadpool thread, because it will
@@ -940,6 +937,9 @@
return;
}
+ // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
+ // be deleted before `done` is called. Take care not to capture `ctx` in any
+ // code that may execute asynchronously in this function.
OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(
std::move(args), &captured_func_->captured_inputs(), ret_types_);
@@ -953,7 +953,7 @@
f_opts.runner = ctx->runner();
f_opts.create_rendezvous = ShouldCreateRendezvous();
auto cancellation_manager =
- std::make_shared<CancellationManager>(ctx->cancellation_manager());
+ absl::make_unique<CancellationManager>(ctx->cancellation_manager());
f_opts.cancellation_manager = cancellation_manager.get();
f_opts.collective_executor = ctx->collective_executor();
@@ -964,13 +964,19 @@
const bool collect_usage = node && ctx->model();
f_opts.stats_collector = stats_collector.get();
- // Transferring ownership of `step_container` and `frame` into `callback`.
- auto callback =
- [this, stats_collector = std::move(stats_collector),
- stats_aggregator = ctx->stats_aggregator(), done = std::move(done),
- cancellation_manager = std::move(cancellation_manager), node, rets,
- step_container, frame, collect_usage](Status s) {
+ // Transfer ownership of the cancellation manager to `callback`.
+ CancellationManager* raw_cancellation_manager =
+ cancellation_manager.release();
+ auto callback = std::bind(
+ [this, rets, step_container, raw_cancellation_manager, frame, node,
+ collect_usage](
+ const FunctionLibraryRuntime::DoneCallback& done,
+ IteratorContext* ctx,
+ const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
+ // Begin unbound arguments.
+ Status s) {
delete step_container;
+ delete raw_cancellation_manager;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
}
@@ -979,11 +985,11 @@
// TODO(b/129085499) Utilize the `node_name` which would be unique
// than the prefix for the function execution time statistics.
// prefix_with_func_name would then be node_name + func_name.
- if (stats_aggregator) {
+ if (ctx->stats_aggregator()) {
string prefix_with_func_name =
strings::StrCat(node->name(), stats_utils::kDelimiter,
captured_func_->func().name());
- stats_aggregator->AddToHistogram(
+ ctx->stats_aggregator()->AddToHistogram(
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
{static_cast<float>(stats_collector->processing_time())},
node->num_elements());
@@ -997,7 +1003,8 @@
if (collect_usage) {
node->record_stop(EnvTime::NowNanos());
}
- };
+ },
+ std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);
profiler::TraceMe activity(
[&] {
diff --git a/tensorflow/core/data/captured_function.h b/tensorflow/core/data/captured_function.h
index afa197f..82d2b1a 100644
--- a/tensorflow/core/data/captured_function.h
+++ b/tensorflow/core/data/captured_function.h
@@ -56,7 +56,7 @@
const std::vector<Tensor>& input_element, int64_t thread_index,
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator,
- std::shared_ptr<model::Node> node);
+ const std::shared_ptr<model::Node>& node);
struct ShortCircuitInfo {
std::vector<int> indices;
@@ -246,7 +246,7 @@
// called `DatasetBaseIterator::RecordStart().
Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- std::shared_ptr<model::Node> node) const;
+ const std::shared_ptr<model::Node>& node) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
@@ -264,7 +264,7 @@
Status RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets,
- std::shared_ptr<model::Node> node) const;
+ const std::shared_ptr<model::Node>& node) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
@@ -285,7 +285,7 @@
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done,
- std::shared_ptr<model::Node> node) const;
+ const std::shared_ptr<model::Node>& node) const;
private:
friend class CapturedFunction;
diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
index c323b0a..fe71f21 100644
--- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc
@@ -77,9 +77,6 @@
// Computes ceil(x / y).
inline int64_t CeilDiv(int64_t x, int64_t y) { return (x + y - 1) / y; }
-// Period between reporting dataset statistics.
-constexpr int kStatsReportingPeriodMillis = 1000;
-
} // namespace
class MapAndBatchDatasetOp::Dataset : public DatasetBase {
@@ -381,14 +378,25 @@
const uint64 uid = -1;
};
- void CallCompleted(BatchResult* result) TF_LOCKS_EXCLUDED(*mu_) {
+ void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<BatchResult>& result)
+ TF_LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
num_calls_--;
result->num_calls--;
+ const auto& stats_aggregator = ctx->stats_aggregator();
+ if (stats_aggregator) {
+ stats_aggregator->AddScalar(
+ stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
+ static_cast<float>(num_calls_) /
+ static_cast<float>(num_parallel_calls_->value),
+ num_elements());
+ }
cond_var_->notify_all();
}
- void CallFunction(std::shared_ptr<IteratorContext> ctx, BatchResult* result,
+ void CallFunction(std::shared_ptr<IteratorContext> ctx,
+ const std::shared_ptr<BatchResult>& result,
int64_t offset) TF_LOCKS_EXCLUDED(*mu_) {
profiler::TraceMe traceme([&] {
return profiler::TraceMeEncode("MapAndBatchProduce",
@@ -407,7 +415,7 @@
return_early = result->end_of_input || !result->status.ok();
}
if (return_early) {
- CallCompleted(result);
+ CallCompleted(ctx, result);
return;
}
@@ -425,7 +433,7 @@
result->UpdateStatus(status, offset);
if (status.ok()) {
Status allocate_status =
- EnsureOutputAllocated(ctx.get(), *return_values, result);
+ EnsureOutputAllocated(ctx, result, return_values);
if (!allocate_status.ok()) {
result->UpdateStatus(allocate_status, offset);
} else {
@@ -461,7 +469,7 @@
result->num_elements++;
}
}
- CallCompleted(result);
+ CallCompleted(ctx, result);
};
// Apply the map function on `input_element`, storing the result in
@@ -485,53 +493,47 @@
void EnsureRunnerThreadStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
- "tf_data_map_and_batch",
- [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
- RunnerThread(ctx);
- });
- if (ctx->stats_aggregator()) {
- stats_thread_ = ctx->StartThread(
- "tf_data_map_and_batch_stats",
- [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
- StatsThread(ctx.get());
- });
- }
+ kTFDataMapAndBatch,
+ std::bind(&Iterator::RunnerThread, this, ctx_copy));
}
}
- Status EnsureOutputAllocated(IteratorContext* ctx,
- const std::vector<Tensor>& return_values,
- BatchResult* result) {
+ Status EnsureOutputAllocated(
+ const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<BatchResult>& result,
+ const std::shared_ptr<std::vector<Tensor>>& return_values) {
mutex_lock l(result->mu);
if (result->output_allocated) {
return Status::OK();
}
- const size_t num_components = return_values.size();
+ const size_t num_components = return_values->size();
result->output.reserve(num_components);
for (size_t i = 0; i < num_components; ++i) {
TensorShape component_shape({dataset()->batch_size_});
- component_shape.AppendShape(return_values.at(i).shape());
+ component_shape.AppendShape(return_values->at(i).shape());
AllocatorAttributes attr;
attr.set_gpu_compatible(true);
- result->output.emplace_back(
- ctx->allocator(attr), return_values.at(i).dtype(), component_shape);
+ result->output.emplace_back(ctx->allocator(attr),
+ return_values->at(i).dtype(),
+ component_shape);
if (!result->output.back().IsInitialized()) {
return errors::ResourceExhausted(
"Failed to allocate memory for the batch of component ", i);
}
}
- RecordBufferEnqueue(ctx, result->output);
+ RecordBufferEnqueue(ctx.get(), result->output);
result->output_allocated = true;
return Status::OK();
}
- void RunnerThread(std::shared_ptr<IteratorContext> ctx)
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
TF_LOCKS_EXCLUDED(*mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64_t>> new_calls;
RecordStart(ctx.get());
auto stop_cleanup =
- gtl::MakeCleanup([this, ctx]() { RecordStop(ctx.get()); });
+ gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
{
tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu
new_calls.reserve(num_parallel_calls_->value);
@@ -575,41 +577,22 @@
num_calls_++;
}
}
+ const auto& stats_aggregator = ctx->stats_aggregator();
+ if (stats_aggregator) {
+ mutex_lock l(*mu_);
+ stats_aggregator->AddScalar(
+ stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
+ static_cast<float>(num_calls_) /
+ static_cast<float>(num_parallel_calls_->value),
+ num_elements());
+ }
for (const auto& call : new_calls) {
- CallFunction(ctx, call.first.get(), call.second);
+ CallFunction(ctx, call.first, call.second);
}
new_calls.clear();
}
}
- void StatsThread(IteratorContext* ctx) {
- for (int64_t step = 0;; ++step) {
- int num_calls;
- int num_parallel_calls;
- {
- mutex_lock l(*mu_);
- if (step != 0 && !cancelled_) {
- cond_var_->wait_for(
- l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
- }
- if (cancelled_) {
- return;
- }
- num_calls = num_calls_;
- num_parallel_calls = num_parallel_calls_->value;
- }
- if (num_parallel_calls == 0) {
- // Avoid division by zero.
- num_parallel_calls = 1;
- }
- ctx->stats_aggregator()->AddScalar(
- stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
- static_cast<float>(num_calls) /
- static_cast<float>(num_parallel_calls),
- step);
- }
- }
-
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
batch_results_.push_back(
@@ -696,8 +679,6 @@
std::deque<std::shared_ptr<BatchResult>> batch_results_ TF_GUARDED_BY(*mu_);
// Background thread used for coordinating input processing.
std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
- // Background thread used for collecting statistics.
- std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
// Determines whether the transformation has been cancelled.
bool cancelled_ TF_GUARDED_BY(*mu_) = false;
// Identifies the number of callers currently waiting for a batch result.
diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
index 665e9b0..68feac3 100644
--- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc
@@ -564,11 +564,10 @@
if (reader->Contains(prefix(), kWorkerThreadsRunning)) {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->StartThread(
strings::StrCat(kDataParallelInterleaveWorker, "_", i),
- [this, ctx = std::make_shared<IteratorContext>(*ctx), i]() {
- WorkerThread(ctx.get(), i);
- }));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
return Status::OK();
@@ -679,11 +678,10 @@
return Status::OK();
}
workers_[i].SetInputs(s, std::move(args));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.push_back(ctx->StartThread(
strings::StrCat(kDataParallelInterleaveWorker, "_", i),
- [this, ctx = std::make_shared<IteratorContext>(*ctx), i]() {
- WorkerThread(ctx.get(), i);
- }));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
@@ -697,7 +695,8 @@
}
// Produces elements into the worker's output buffers.
- void WorkerThread(IteratorContext* ctx, int64_t thread_index) {
+ void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
+ const int64_t thread_index) {
// Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
//
// 1. Any local state that may need to be checkpointed should be kept
@@ -718,11 +717,11 @@
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
- RecordStart(ctx);
+ RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
- RecordStop(ctx);
+ RecordStop(ctx.get());
});
bool make_new_iterator;
{
@@ -759,9 +758,9 @@
if (read_new_input) {
mutex_lock l(mu_);
while (!cancelled_ && !workers_[thread_index].is_producing) {
- RecordStop(ctx);
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
- RecordStart(ctx);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
// Copy the input tensors so that we do not need to block on `mu_`
@@ -783,7 +782,7 @@
tf_shared_lock l(ckpt_mu_);
worker_thread_states_[thread_index].iterator_creation_status =
MakeIteratorFromInputElement(
- ctx, this, worker_thread_states_[thread_index].input,
+ ctx.get(), this, worker_thread_states_[thread_index].input,
thread_index, *instantiated_captured_func_, prefix(),
&worker_thread_states_[thread_index].iterator,
model_node());
@@ -812,9 +811,9 @@
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
- RecordStop(ctx);
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
- RecordStart(ctx);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
@@ -851,7 +850,7 @@
profiler::kInfo);
worker_thread_states_[thread_index].output_elem.status =
worker_thread_states_[thread_index].iterator->GetNext(
- ctx,
+ ctx.get(),
&worker_thread_states_[thread_index].output_elem.output,
&worker_thread_states_[thread_index].end_of_sequence);
end_of_sequence =
@@ -873,9 +872,9 @@
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
- RecordStop(ctx);
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
- RecordStart(ctx);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
index a894565..897f1c4 100644
--- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc
@@ -412,7 +412,7 @@
return profiler::TraceMeEncode("ParseExampleConsume",
{{"element_id", result->id}});
});
- return ProcessResult(ctx, result.get(), out_tensors, end_of_sequence);
+ return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
@@ -554,32 +554,31 @@
void EnsureThreadsStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
- "tf_data_parse_example",
- [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
- RunnerThread(ctx);
- });
+ "tf_data_parallel_map",
+ std::bind(&Iterator::RunnerThread, this, ctx_copy));
if (ctx->stats_aggregator()) {
stats_thread_ = ctx->StartThread(
- "tf_data_parse_example_stats",
- [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
- StatsThread(ctx.get());
- });
+ "tf_data_parallel_map_stats",
+ std::bind(&Iterator::StatsThread, this, ctx_copy));
}
}
}
- void CallCompleted(IteratorContext* ctx, InvocationResult* result)
+ void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<InvocationResult>& result)
TF_LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
num_calls_--;
- RecordBufferEnqueue(ctx, result->return_values);
+ RecordBufferEnqueue(ctx.get(), result->return_values);
result->notification.Notify();
cond_var_->notify_all();
}
- void CallFunction(std::shared_ptr<IteratorContext> ctx,
- InvocationResult* result) TF_LOCKS_EXCLUDED(*mu_) {
+ void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<InvocationResult>& result)
+ TF_LOCKS_EXCLUDED(*mu_) {
profiler::TraceMe traceme([&] {
return profiler::TraceMeEncode("ParseExampleProduce",
{{"element_id", result->id}});
@@ -589,13 +588,13 @@
result->status = input_impl_->GetNext(ctx.get(), &input_element,
&result->end_of_input);
if (result->end_of_input || !result->status.ok()) {
- CallCompleted(ctx.get(), result);
+ CallCompleted(ctx, result);
return;
}
auto done = [this, ctx, result](Status status) {
result->status.Update(status);
- CallCompleted(ctx.get(), result);
+ CallCompleted(ctx, result);
};
// We schedule the `ParseExample` function using `ctx->runner()` to
@@ -724,7 +723,8 @@
return Status::OK();
}
- Status ProcessResult(IteratorContext* ctx, InvocationResult* result,
+ Status ProcessResult(IteratorContext* ctx,
+ const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) {
if (!result->end_of_input && result->status.ok()) {
@@ -745,7 +745,7 @@
return result->status;
}
- void RunnerThread(std::shared_ptr<IteratorContext> ctx)
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
TF_LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
@@ -781,7 +781,7 @@
cond_var_->notify_all();
}
for (const auto& call : new_calls) {
- CallFunction(ctx, call.get());
+ CallFunction(ctx, call);
}
new_calls.clear();
}
@@ -819,7 +819,7 @@
return true;
}
- void StatsThread(IteratorContext* ctx) {
+ void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
for (int64_t step = 0;; ++step) {
int num_calls;
int num_parallel_calls;
diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc
index b0d7ab8..46be263 100644
--- a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc
@@ -337,7 +337,9 @@
const int64_t uid = -1;
};
- void CallCompleted(BatchResult* result) TF_LOCKS_EXCLUDED(*mu_) {
+ void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<BatchResult>& result)
+ TF_LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
num_calls_--;
result->call_finished = true;
@@ -347,7 +349,8 @@
// The function fetches elements from input dataset sequentially and then
// executes the batching for different batches in parallel using the context
// runner.
- void CallBatching(std::shared_ptr<IteratorContext> ctx, BatchResult* result)
+ void CallBatching(std::shared_ptr<IteratorContext> ctx,
+ const std::shared_ptr<BatchResult>& result)
TF_LOCKS_EXCLUDED(*mu_) {
profiler::TraceMe traceme([&] {
return profiler::TraceMeEncode("ParallelBatchProduce",
@@ -355,7 +358,7 @@
});
if (!input_impl_) {
- CallCompleted(result);
+ CallCompleted(ctx, result);
return;
}
@@ -386,7 +389,7 @@
}
if (batch_elements->empty()) {
- CallCompleted(result);
+ CallCompleted(ctx, result);
return;
}
@@ -406,7 +409,7 @@
std::move(allocation_callback), &result->output);
result->status.Update(status);
}
- CallCompleted(result);
+ CallCompleted(ctx, result);
return status;
};
@@ -427,15 +430,14 @@
void EnsureRunnerThreadStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
- "tf_data_parallel_batch",
- [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
- RunnerThread(ctx);
- });
+ kTFDataParallelBatch,
+ std::bind(&Iterator::RunnerThread, this, ctx_copy));
}
}
- void RunnerThread(std::shared_ptr<IteratorContext> ctx)
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
TF_LOCKS_EXCLUDED(*mu_) {
std::vector<std::shared_ptr<BatchResult>> new_calls;
RecordStart(ctx.get());
@@ -470,7 +472,7 @@
}
}
for (const auto& call : new_calls) {
- CallBatching(ctx, call.get());
+ CallBatching(ctx, call);
}
new_calls.clear();
}
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index b67491b..7a5260a 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -257,7 +257,7 @@
return profiler::TraceMeEncode("ParallelMapConsume",
{{"element_id", result->uid}});
});
- return ProcessResult(ctx, result.get(), out_tensors, end_of_sequence);
+ return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
@@ -395,30 +395,30 @@
void EnsureThreadsStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_ = ctx->StartThread(
"tf_data_parallel_map",
- [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
- RunnerThread(ctx);
- });
+ std::bind(&Iterator::RunnerThread, this, ctx_copy));
if (ctx->stats_aggregator()) {
stats_thread_ = ctx->StartThread(
"tf_data_parallel_map_stats",
- [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
- StatsThread(ctx.get());
- });
+ std::bind(&Iterator::StatsThread, this, ctx_copy));
}
}
}
- void CallCompleted(InvocationResult* result) TF_LOCKS_EXCLUDED(*mu_) {
+ void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<InvocationResult>& result)
+ TF_LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
num_calls_--;
result->notification.Notify();
cond_var_->notify_all();
}
- void CallFunction(std::shared_ptr<IteratorContext> ctx,
- InvocationResult* result) TF_LOCKS_EXCLUDED(*mu_) {
+ void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<InvocationResult>& result)
+ TF_LOCKS_EXCLUDED(*mu_) {
profiler::TraceMe traceme([&] {
return profiler::TraceMeEncode("ParallelMapProduce",
{{"element_id", result->uid}});
@@ -428,14 +428,14 @@
result->status = input_impl_->GetNext(ctx.get(), &input_element,
&result->end_of_input);
if (result->end_of_input || !result->status.ok()) {
- CallCompleted(result);
+ CallCompleted(ctx, result);
return;
}
auto done = [this, ctx, result](Status status) {
result->status.Update(status);
RecordBufferEnqueue(ctx.get(), result->return_values);
- CallCompleted(result);
+ CallCompleted(ctx, result);
};
// Apply the map function on `input_element`, storing the result in
@@ -472,7 +472,8 @@
}
}
- Status ProcessResult(IteratorContext* ctx, InvocationResult* result,
+ Status ProcessResult(IteratorContext* ctx,
+ const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) {
if (!result->end_of_input && result->status.ok()) {
@@ -500,7 +501,7 @@
return result->status;
}
- void RunnerThread(std::shared_ptr<IteratorContext> ctx)
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
TF_LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
@@ -533,13 +534,13 @@
cond_var_->notify_all();
}
for (const auto& call : new_calls) {
- CallFunction(ctx, call.get());
+ CallFunction(ctx, call);
}
new_calls.clear();
}
}
- // Determines whether the caller needs to wait for a result-> Upon returning
+ // Determines whether the caller needs to wait for a result. Upon returning
// false, `result` will point to the result.
bool ShouldWait(std::shared_ptr<InvocationResult>* result)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
@@ -571,7 +572,7 @@
return true;
}
- void StatsThread(IteratorContext* ctx) {
+ void StatsThread(const std::shared_ptr<IteratorContext>& ctx) {
for (int64_t step = 0;; ++step) {
int num_calls;
int num_parallel_calls;
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 0be61ef..ba15c60 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -438,11 +438,10 @@
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!prefetch_thread_) {
+ std::shared_ptr<IteratorContext> new_ctx =
+ std::make_shared<IteratorContext>(*ctx);
prefetch_thread_ = ctx->StartThread(
- "tf_data_prefetch",
- [this, ctx = std::make_shared<IteratorContext>(*ctx)]() {
- PrefetchThread(ctx.get());
- });
+ "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
}
return Status::OK();
}
@@ -450,9 +449,9 @@
// Prefetches elements of the input, storing results in an internal buffer.
//
// It owns the iterator context passed to it.
- void PrefetchThread(IteratorContext* ctx) {
- RecordStart(ctx);
- auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx); });
+ void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
// Keep track of where we are in an iteration "burst"
int num_produced = 0;
while (true) {
@@ -460,9 +459,9 @@
{
mutex_lock l(*mu_);
while (!cancelled_ && buffer_.size() >= buffer_limit()) {
- RecordStop(ctx);
+ RecordStop(ctx.get());
cond_var_->wait(l);
- RecordStart(ctx);
+ RecordStart(ctx.get());
}
if (cancelled_) {
@@ -496,7 +495,7 @@
},
profiler::kInfo);
buffer_element.status = input_impl_->GetNext(
- ctx, &buffer_element.value, &end_of_sequence);
+ ctx.get(), &buffer_element.value, &end_of_sequence);
}
if (buffer_element.status.ok() && end_of_sequence) {
mutex_lock l(*mu_);
@@ -508,7 +507,7 @@
// 3. Signal that the element has been produced.
{
mutex_lock l(*mu_);
- RecordBufferEnqueue(ctx, buffer_element.value);
+ RecordBufferEnqueue(ctx.get(), buffer_element.value);
buffer_element.created_us = EnvTime::NowMicros();
buffer_.push_back(std::move(buffer_element));
cond_var_->notify_all();