Fix tsan failure.
Instead of creating a new threadpool in ReadElementsParallel, we can reuse the existing thread_pool_ field. This solves the tsan failure because now the destructor will block until the threads created in ReadElementsParallel exit.
PiperOrigin-RevId: 325341982
Change-Id: I1107bde215a5384ded98633f5f46a2dde3ff7e23
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 94cc31a..1365f8a 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -623,7 +623,6 @@
name = "parallel_interleave_dataset_op_test",
size = "small",
srcs = ["parallel_interleave_dataset_op_test.cc"],
- tags = ["notsan"], # TODO(b/147147071): Remove this tag once bug fix lands.
deps = [
":captured_function",
":dataset_test_base",
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 54ad888..90dd533 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -41,6 +41,7 @@
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
@@ -1342,12 +1343,10 @@
IteratorContext* ctx, IteratorStateReader* reader, int64 size,
const string& name, std::vector<std::shared_ptr<Element>>* elements) {
elements->resize(size);
- std::unique_ptr<thread::ThreadPool> threadpool =
- ctx->CreateThreadPool(absl::StrCat("read_", name), size);
Status s = Status::OK();
BlockingCounter counter(size);
for (int idx = 0; idx < size; ++idx) {
- threadpool->Schedule(
+ thread_pool_->Schedule(
[this, ctx, reader, idx, name, &s, &counter, elements] {
RecordStart(ctx);
auto cleanup = gtl::MakeCleanup([this, ctx, &counter]() {
@@ -1357,6 +1356,11 @@
std::shared_ptr<Element> elem;
Status ret_status = ReadElement(ctx, reader, idx, name, &elem);
mutex_lock l(*mu_);
+ if (cancelled_) {
+ s.Update(
+ errors::Cancelled("Cancelled in ReadElementsParallel"));
+ return;
+ }
if (!ret_status.ok()) {
s.Update(ret_status);
return;