[tf.data] Use the cached finalized dataset when finalizing during sequential access.
PiperOrigin-RevId: 426551686
Change-Id: Iac55dd07e9a2c428cd6028141a6b2ed013c2d935
diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD
index 432651c..3c460e0 100644
--- a/tensorflow/core/data/BUILD
+++ b/tensorflow/core/data/BUILD
@@ -24,6 +24,8 @@
"captured_function.h",
"dataset_utils.cc",
"dataset_utils.h",
+ "finalization_utils.cc",
+ "finalization_utils.h",
"name_utils.cc",
"name_utils.h",
"rewrite_utils.cc",
diff --git a/tensorflow/core/data/finalization_utils.cc b/tensorflow/core/data/finalization_utils.cc
index 359e139..64296bf 100644
--- a/tensorflow/core/data/finalization_utils.cc
+++ b/tensorflow/core/data/finalization_utils.cc
@@ -23,7 +23,7 @@
namespace data {
StatusOr<DatasetBase*> GetFinalizedDataset(OpKernelContext* ctx,
- DatasetBase* dataset) {
+ const DatasetBase* dataset) {
return dataset->Finalize(
ctx, [ctx, dataset]() -> StatusOr<core::RefCountPtr<DatasetBase>> {
core::RefCountPtr<DatasetBase> dataset_ref_ptr;
diff --git a/tensorflow/core/data/finalization_utils.h b/tensorflow/core/data/finalization_utils.h
index 9edfa0a..f019c20 100644
--- a/tensorflow/core/data/finalization_utils.h
+++ b/tensorflow/core/data/finalization_utils.h
@@ -24,9 +24,10 @@
namespace tensorflow {
namespace data {
-// Returns the finalized version of the dataset.
+// Returns the finalized version of the dataset. The returned DatasetBase is
+// unowned and remains valid for the lifetime of the input `dataset`.
StatusOr<DatasetBase*> GetFinalizedDataset(OpKernelContext* ctx,
- DatasetBase* dataset);
+ const DatasetBase* dataset);
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc
index 4e4240b..79c93f3 100644
--- a/tensorflow/core/data/root_dataset.cc
+++ b/tensorflow/core/data/root_dataset.cc
@@ -15,6 +15,10 @@
#include "tensorflow/core/data/root_dataset.h"
+#include <functional>
+#include <string>
+#include <utility>
+
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/rewrite_utils.h"
@@ -45,39 +49,84 @@
return x == y ? z : x;
}
+void SetRootDatasetParams(const Options& options, RootDataset::Params* params) {
+ if (ShouldConfigureMaxIntraOpParallelism(options)) {
+ params->max_intra_op_parallelism =
+ options.threading_options().max_intra_op_parallelism();
+ }
+ if (ShouldUsePrivateThreadPool(options)) {
+ params->private_threadpool_size =
+ options.threading_options().private_threadpool_size();
+ }
+ params->autotune = ShouldUseAutotuning(options);
+ if (params->autotune) {
+ params->autotune_algorithm =
+ options.autotune_options().optional_autotune_algorithm_case() ==
+ AutotuneOptions::kAutotuneAlgorithm
+ ? options.autotune_options().autotune_algorithm()
+ : model::AutotuneAlgorithm::DEFAULT;
+ params->autotune_cpu_budget = value_or_default(
+ options.autotune_options().cpu_budget(), 0, GetCpuBudget());
+ params->autotune_ram_budget =
+ value_or_default(options.autotune_options().ram_budget(), 0,
+ model::kRamBudgetShare * port::AvailableRam());
+ }
+}
+
+void AddTraceMetadata(const RootDataset::Params& params,
+ TraceMeMetadata* trace_metadata) {
+ if (params.autotune) {
+ trace_metadata->push_back(std::make_pair(
+ kAlgorithm, model::AutotuneAlgorithm_Name(params.autotune_algorithm)));
+ trace_metadata->push_back(std::make_pair(
+ kCpuBudget, strings::Printf("%lld", static_cast<long long>(
+ params.autotune_cpu_budget))));
+ trace_metadata->push_back(std::make_pair(
+ kRamBudget,
+ strings::Printf("%lld", static_cast<long long>(
+ params.autotune_ram_budget / 1.0e6))));
+ }
+ if (params.max_intra_op_parallelism >= 0) {
+ trace_metadata->push_back(std::make_pair(
+ kIntraOpParallelism,
+ strings::Printf("%lld", static_cast<long long>(value_or_default(
+ params.max_intra_op_parallelism, 0,
+ port::MaxParallelism())))));
+ }
+ if (params.private_threadpool_size >= 0) {
+ trace_metadata->push_back(std::make_pair(
+ kPrivateThreadpoolSize,
+ strings::Printf("%lld", static_cast<long long>(value_or_default(
+ params.private_threadpool_size, 0,
+ port::MaxParallelism())))));
+ }
+ auto experiments = GetExperiments();
+ if (!experiments.empty()) {
+ trace_metadata->push_back(
+ std::make_pair(kExperiments, absl::StrJoin(experiments, " ")));
+ }
+}
} // namespace
// static
Status RootDataset::FromOptions(const DatasetBase* input,
DatasetBase** output) {
- const Options& options = input->options();
Params params;
- if (ShouldConfigureMaxIntraOpParallelism(options)) {
- params.max_intra_op_parallelism =
- options.threading_options().max_intra_op_parallelism();
- }
- if (ShouldUsePrivateThreadPool(options)) {
- params.private_threadpool_size =
- options.threading_options().private_threadpool_size();
- }
- params.autotune = ShouldUseAutotuning(options);
- if (params.autotune) {
- params.autotune_algorithm =
- options.autotune_options().optional_autotune_algorithm_case() ==
- AutotuneOptions::kAutotuneAlgorithm
- ? options.autotune_options().autotune_algorithm()
- : model::AutotuneAlgorithm::DEFAULT;
- params.autotune_cpu_budget = value_or_default(
- options.autotune_options().cpu_budget(), 0, GetCpuBudget());
- params.autotune_ram_budget =
- value_or_default(options.autotune_options().ram_budget(), 0,
- model::kRamBudgetShare * port::AvailableRam());
- }
+ SetRootDatasetParams(input->options(), ¶ms);
*output = new RootDataset(input, params);
(*output)->Initialize(/*metadata=*/{});
return Status::OK();
}
+Status RootDataset::FromOptions(core::RefCountPtr<DatasetBase> input,
+ DatasetBase** output) {
+ Params params;
+ SetRootDatasetParams(input->options(), ¶ms);
+ *output = new RootDataset(std::move(input), params);
+ (*output)->Initialize(/*metadata=*/{});
+ return Status::OK();
+}
+
class RootDataset::Iterator : public DatasetIterator<RootDataset> {
public:
explicit Iterator(const Params& params)
@@ -215,45 +264,25 @@
std::unique_ptr<IteratorBase> input_impl_;
};
-RootDataset::RootDataset(const DatasetBase* input, Params params)
+RootDataset::RootDataset(const DatasetBase* input, const Params& params)
: DatasetBase(DatasetContext({name_utils::OpName(kDatasetType),
name_utils::OpName(kDatasetType)})),
input_(input),
params_(std::move(params)) {
- if (params_.autotune) {
- traceme_metadata_.push_back(std::make_pair(
- kAlgorithm, model::AutotuneAlgorithm_Name(params_.autotune_algorithm)));
- traceme_metadata_.push_back(std::make_pair(
- kCpuBudget, strings::Printf("%lld", static_cast<long long>(
- params_.autotune_cpu_budget))));
- traceme_metadata_.push_back(std::make_pair(
- kRamBudget,
- strings::Printf("%lld", static_cast<long long>(
- params_.autotune_ram_budget / 1.0e6))));
- }
- if (params_.max_intra_op_parallelism >= 0) {
- traceme_metadata_.push_back(std::make_pair(
- kIntraOpParallelism,
- strings::Printf("%lld", static_cast<long long>(value_or_default(
- params_.max_intra_op_parallelism, 0,
- port::MaxParallelism())))));
- }
- if (params_.private_threadpool_size >= 0) {
- traceme_metadata_.push_back(std::make_pair(
- kPrivateThreadpoolSize,
- strings::Printf("%lld", static_cast<long long>(value_or_default(
- params_.private_threadpool_size, 0,
- port::MaxParallelism())))));
- }
- auto experiments = GetExperiments();
- if (!experiments.empty()) {
- traceme_metadata_.push_back(
- std::make_pair(kExperiments, absl::StrJoin(experiments, " ")));
- }
- input_->Ref();
+ AddTraceMetadata(params_, &traceme_metadata_);
}
-RootDataset::~RootDataset() { input_->Unref(); }
+RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
+ const Params& params)
+ : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType),
+ name_utils::OpName(kDatasetType)})),
+ params_(std::move(params)) {
+ owned_input_ = std::move(input);
+ input_ = owned_input_.get();
+ AddTraceMetadata(params_, &traceme_metadata_);
+}
+
+RootDataset::~RootDataset() {}
std::unique_ptr<IteratorBase> RootDataset::MakeIteratorInternal(
const string& prefix) const {
@@ -337,20 +366,25 @@
};
Status s = RewriteDataset(ctx, input, std::move(config_factory),
/*record_fingerprint=*/true, output);
+ bool rewritten = (*output != input);
if (errors::IsDeadlineExceeded(s)) {
// Ignore DeadlineExceeded as it implies that the attempted rewrite took too
// long which should not prevent further computation.
LOG(WARNING) << s.ToString();
- return RootDataset::FromOptions(input, output);
- }
- if (!s.ok()) {
+ } else if (!s.ok()) {
return s;
}
- input = *output;
- TF_RETURN_IF_ERROR(RootDataset::FromOptions(input, output));
- input->Unref();
+ if (!rewritten) {
+ return RootDataset::FromOptions(input, output);
+ } else {
+ input = *output;
+ core::RefCountPtr<DatasetBase> dataset_ref_ptr(
+ const_cast<DatasetBase*>(input));
+ return RootDataset::FromOptions(std::move(dataset_ref_ptr), output);
+ }
return Status::OK();
}
+
#else // !IS_MOBILE_PLATFORM
Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input,
DatasetBase** output) {
diff --git a/tensorflow/core/data/root_dataset.h b/tensorflow/core/data/root_dataset.h
index 11e28e4..70c9708 100644
--- a/tensorflow/core/data/root_dataset.h
+++ b/tensorflow/core/data/root_dataset.h
@@ -18,6 +18,7 @@
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/model.pb.h"
+#include "tensorflow/core/platform/refcount.h"
namespace tensorflow {
namespace data {
@@ -36,6 +37,8 @@
};
static Status FromOptions(const DatasetBase* input, DatasetBase** output);
+ static Status FromOptions(core::RefCountPtr<DatasetBase> input,
+ DatasetBase** output);
~RootDataset() override;
@@ -60,9 +63,12 @@
private:
class Iterator;
- RootDataset(const DatasetBase* input, Params params);
+ RootDataset(const DatasetBase* input, const Params& params);
+
+ RootDataset(core::RefCountPtr<DatasetBase> input, const Params& params);
const DatasetBase* input_;
+ core::RefCountPtr<DatasetBase> owned_input_;
const Params params_;
TraceMeMetadata traceme_metadata_;
};
@@ -70,7 +76,8 @@
// Finalizes the `input` dataset, which is expected to be called before the
// dataset is about to be iterated. This can for instance apply static graph
// optimizations or inject internal tf.data transformations responsible for
-// autotuning or threading configuration.
+// autotuning or threading configuration. The caller must ensure that the
+// `input` dataset outlives the returned `output` dataset.
Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input,
DatasetBase** output);
diff --git a/tensorflow/core/data/standalone.cc b/tensorflow/core/data/standalone.cc
index 97ee2a0..8e0abe8 100644
--- a/tensorflow/core/data/standalone.cc
+++ b/tensorflow/core/data/standalone.cc
@@ -116,9 +116,9 @@
OpKernelContext ctx(&op_params, /*num_outputs=*/0);
TF_RETURN_IF_ERROR(data::FinalizeDataset(&ctx, dataset, &finalized_dataset));
core::ScopedUnref unref(finalized_dataset);
- *result = WrapUnique(new Dataset(finalized_dataset, device_mgr.release(),
- pflr.release(), flib_def.release(),
- pool.release(), std::move(runner)));
+ *result = WrapUnique(new Dataset(
+ finalized_dataset, dataset, device_mgr.release(), pflr.release(),
+ flib_def.release(), pool.release(), std::move(runner)));
return Status::OK();
} // static
@@ -146,8 +146,8 @@
// Create the iterator from the dataset.
std::unique_ptr<IteratorBase> iterator;
- TF_RETURN_IF_ERROR(dataset_->MakeIterator(ctx.get(), /*parent=*/nullptr,
- "Iterator", &iterator));
+ TF_RETURN_IF_ERROR(finalized_dataset_->MakeIterator(
+ ctx.get(), /*parent=*/nullptr, "Iterator", &iterator));
*result = WrapUnique(new Iterator(iterator.release(), ctx.release()));
return Status::OK();
@@ -159,28 +159,33 @@
Status Dataset::MakeSplitProviders(
std::vector<std::unique_ptr<SplitProvider>>* result) {
- return dataset_->MakeSplitProviders(result);
+ return finalized_dataset_->MakeSplitProviders(result);
}
-const DatasetBase* Dataset::Get() const { return dataset_; }
+const DatasetBase* Dataset::Get() const { return finalized_dataset_; }
-Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
- ProcessFunctionLibraryRuntime* pflr,
+Dataset::Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset,
+ DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr,
FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
std::function<void(std::function<void()>)> runner)
- : dataset_(dataset),
+ : finalized_dataset_(finalized_dataset),
+ original_dataset_(original_dataset),
device_mgr_(device_mgr),
flib_def_(flib_def),
pflr_(pflr),
interop_threadpool_(pool),
runner_(std::move(runner)),
unbounded_thread_pool_(Env::Default(), "tf_data_standalone") {
- dataset_->Ref();
+ finalized_dataset_->Ref();
+ original_dataset_->Ref();
function_handle_cache_ =
absl::make_unique<FunctionHandleCache>(pflr_->GetFLR("/device:CPU:0"));
}
-Dataset::~Dataset() { dataset_->Unref(); }
+Dataset::~Dataset() {
+ finalized_dataset_->Unref();
+ original_dataset_->Unref();
+}
} // namespace standalone
} // namespace data
diff --git a/tensorflow/core/data/standalone.h b/tensorflow/core/data/standalone.h
index c641cac..c35e6a5 100644
--- a/tensorflow/core/data/standalone.h
+++ b/tensorflow/core/data/standalone.h
@@ -112,12 +112,13 @@
const DatasetBase* Get() const;
private:
- Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
- ProcessFunctionLibraryRuntime* pflr,
+ Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset,
+ DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr,
FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
std::function<void(std::function<void()>)> runner);
- DatasetBase* dataset_; // owned
+ DatasetBase* finalized_dataset_; // owned
+ DatasetBase* original_dataset_; // owned
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 79c89a6..3533e4a 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -646,7 +646,7 @@
StatusOr<DatasetBase*> DatasetBase::Finalize(
OpKernelContext* ctx,
std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
- make_finalized_dataset) {
+ make_finalized_dataset) const {
mutex_lock l(mu_);
if (!finalized_dataset_) {
TF_ASSIGN_OR_RETURN(finalized_dataset_, make_finalized_dataset());
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 36db239..120d206 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -1032,7 +1032,7 @@
virtual StatusOr<DatasetBase*> Finalize(
OpKernelContext* ctx,
std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
- make_finalized_dataset);
+ make_finalized_dataset) const;
// Wrapper around a GraphDefBuilder which provides support for serializing
// Datasets as GraphDefs.
@@ -1099,9 +1099,9 @@
const string node_name_;
Metadata metadata_;
Options options_;
- mutex mu_;
+ mutable mutex mu_;
mutable mutex cardinality_mu_;
- core::RefCountPtr<DatasetBase> finalized_dataset_;
+ mutable core::RefCountPtr<DatasetBase> finalized_dataset_;
// The number of source datasets feeding into the dataset. A source dataset
// is a leaf in the subtree of dataset inputs.
int64_t num_sources_ = -1;
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 99bbe51..bcafc82 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -58,6 +58,7 @@
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/data:dataset_utils",
"//tensorflow/core/data:name_utils",
"//tensorflow/core/data:serialization_utils",
"//tensorflow/core/framework:dataset_options_proto_cc",
@@ -419,6 +420,7 @@
"//tensorflow/core:session_options",
"//tensorflow/core/data:captured_function",
"//tensorflow/core/data:dataset_utils",
+ "//tensorflow/core/data:finalization_utils",
"//tensorflow/core/data:root_dataset",
"//tensorflow/core/data:serialization_utils",
"//tensorflow/core/data:unbounded_thread_pool",
@@ -530,6 +532,7 @@
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/data:dataset_utils",
+ "//tensorflow/core/data:finalization_utils",
"//tensorflow/core/data:root_dataset",
"//tensorflow/core/data:unbounded_thread_pool",
"//tensorflow/core/kernels:ops_util",
@@ -1367,6 +1370,7 @@
srcs = [
"//tensorflow/core/data:captured_function.h",
"//tensorflow/core/data:dataset_utils.h",
+ "//tensorflow/core/data:finalization_utils.h",
"//tensorflow/core/data:utils.h",
"//tensorflow/core/data:name_utils.h",
"//tensorflow/core/data:rewrite_utils.h",
@@ -1391,6 +1395,7 @@
":portable_all_op_kernels_headers",
"//tensorflow/core/data:captured_function.cc",
"//tensorflow/core/data:dataset_utils.cc",
+ "//tensorflow/core/data:finalization_utils.cc",
"//tensorflow/core/data:name_utils.cc",
"//tensorflow/core/data:utils.cc",
"//tensorflow/core/data:rewrite_utils.cc",
diff --git a/tensorflow/core/kernels/data/experimental/random_access_ops.h b/tensorflow/core/kernels/data/experimental/random_access_ops.h
index 3450ff5..bcc54ba 100644
--- a/tensorflow/core/kernels/data/experimental/random_access_ops.h
+++ b/tensorflow/core/kernels/data/experimental/random_access_ops.h
@@ -34,6 +34,8 @@
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
+ ~GetElementAtIndexOp() override {}
+
protected:
Status DoCompute(OpKernelContext* ctx) override;
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 5511d9f..5bb26d7 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -27,6 +27,7 @@
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/data/captured_function.h"
#include "tensorflow/core/data/dataset_utils.h"
+#include "tensorflow/core/data/finalization_utils.h"
#include "tensorflow/core/data/root_dataset.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/data/utils.h"
@@ -194,6 +195,7 @@
IteratorStateReader* reader) {
const DatasetBase* dataset;
std::shared_ptr<State> new_state;
+ const DatasetBase* input_dataset;
{
tf_shared_lock l(mu_);
if (!iterator_state_->iterator()) {
@@ -211,6 +213,7 @@
std::make_shared<State>(iterator_state_->flib_def(),
iterator_state_->pflr(), iterator_state_->flr(),
/*iterator=*/nullptr);
+ input_dataset = iterator_state_->dataset();
}
core::ScopedUnref scoped_unref(dataset);
IteratorContext::Params params(ctx);
@@ -229,7 +232,8 @@
std::unique_ptr<IteratorBase> iterator_base;
TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
- new_state->DowncastAndSetIterator(std::move(iterator_base));
+ new_state->DowncastAndSetIteratorAndDataset(std::move(iterator_base),
+ input_dataset);
mutex_lock l(mu_);
std::swap(iterator_state_, new_state);
@@ -265,8 +269,7 @@
std::unique_ptr<IteratorBase> iterator;
if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
DatasetBase* finalized_dataset;
- TF_RETURN_IF_ERROR(FinalizeDataset(ctx, dataset, &finalized_dataset));
- core::ScopedUnref unref(finalized_dataset);
+ TF_ASSIGN_OR_RETURN(finalized_dataset, GetFinalizedDataset(ctx, dataset));
TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
IteratorContext(std::move(params)),
/*parent=*/nullptr, "Iterator", &iterator));
@@ -280,7 +283,7 @@
TF_RETURN_IF_ERROR(
VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
- new_state->DowncastAndSetIterator(std::move(iterator));
+ new_state->DowncastAndSetIteratorAndDataset(std::move(iterator), dataset);
mutex_lock l(mu_);
std::swap(iterator_state_, new_state);
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index 6fc9339..5c3188e 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -16,6 +16,8 @@
#ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
+#include <utility>
+
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/unbounded_thread_pool.h"
@@ -26,6 +28,7 @@
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/refcount.h"
namespace tensorflow {
namespace data {
@@ -89,9 +92,14 @@
~State() { cancellation_manager_.StartCancel(); }
// Downcasts the given `IteratorBase` to a `DatasetBaseIterator`, and uses
- // it to set the `iterator` field.
- void DowncastAndSetIterator(std::unique_ptr<IteratorBase> it) {
+ // it to set the `iterator` and the `dataset` field.
+ void DowncastAndSetIteratorAndDataset(std::unique_ptr<IteratorBase> it,
+ const DatasetBase* dataset) {
iterator_.reset(static_cast<DatasetBaseIterator*>(it.release()));
+ if (dataset) {
+ dataset->Ref();
+ dataset_.reset(const_cast<DatasetBase*>(dataset));
+ }
}
std::shared_ptr<FunctionLibraryDefinition> flib_def() { return flib_def_; }
@@ -112,6 +120,8 @@
DatasetBaseIterator* iterator() { return iterator_.get(); }
+ DatasetBase* dataset() { return dataset_.get(); }
+
private:
std::shared_ptr<FunctionLibraryDefinition> flib_def_;
FunctionLibraryRuntime* flr_ = nullptr; // not owned
@@ -120,6 +130,7 @@
ResourceMgr resource_mgr_;
CancellationManager cancellation_manager_;
std::unique_ptr<DatasetBaseIterator> iterator_;
+ core::RefCountPtr<DatasetBase> dataset_;
};
UnboundedThreadPool unbounded_thread_pool_;
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
index 7f3fdc1..3151e04 100644
--- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -92,7 +92,7 @@
}
Status Init(std::unique_ptr<IteratorBase> iterator, int64_t max_buffer_size,
- int64_t* incarnation_id) {
+ int64_t* incarnation_id, DatasetBase* dataset) {
if (iterator) {
TF_RETURN_IF_ERROR(
VerifyTypesMatch(output_types_, iterator->output_dtypes()));
@@ -104,6 +104,8 @@
if (multi_device_buffer_) {
multi_device_buffer_->Reset();
}
+ dataset->Ref();
+ dataset_.reset(dataset);
++incarnation_id_;
*incarnation_id = incarnation_id_;
@@ -450,6 +452,7 @@
int64_t incarnation_id_ TF_GUARDED_BY(mu_) = 0;
std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ TF_GUARDED_BY(mu_);
+ core::RefCountPtr<DatasetBase> dataset_;
};
// Used to generate unique names for anonymous multi device iterators.
@@ -654,7 +657,7 @@
core::ScopedUnref unref(finalized_dataset);
int64_t incarnation_id;
OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
- &incarnation_id));
+ &incarnation_id, dataset));
Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
tensor_incarnation_id.scalar<int64_t>()() = incarnation_id;
OP_REQUIRES_OK(ctx,
@@ -825,9 +828,9 @@
void Compute(OpKernelContext* ctx) override {
ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
- // The iterator resource is guaranteed to exist because the variant tensor
- // wrapping the deleter is provided as an unused input to this op, which
- // guarantees that it has not run yet.
+ // The iterator resource is guaranteed to
+ // exist because the variant tensor wrapping the deleter is provided as an
+ // unused input to this op, which guarantees that it has not run yet.
OP_REQUIRES_OK(ctx, DeleteResource(ctx, handle));
}
};
diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py
index 6f442ea..3669dfa 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_test.py
@@ -377,23 +377,27 @@
self.evaluate(get_next())
@combinations.generate(
- combinations.times(
- test_base.default_test_combinations(),
- combinations.combine(reshuffle=[True, False])))
- def testRerandomizeOnReplicate(self, reshuffle):
+ combinations.times(test_base.default_test_combinations(),
+ combinations.combine(reshuffle=[True, False])))
+ def testDontRerandomizeOnReplicate(self, reshuffle):
random_seed.set_random_seed(None)
- # When no seeds are fixed, each instantiation of the shuffle dataset should
- # produce elements in a different order.
+ # Since the seed generator configuration is preserved across serialization
+ # of the dataset, each instantiation of the shuffle dataset
+ # should preserve the shuffle order if reshuffle=False. To preserve the
+ # shuffle order, the original dataset must be kept alive, since if the
+ # original dataset was destroyed, its seeds would also be destroyed.
num_elements = 100
- dataset = dataset_ops.Dataset.range(num_elements)
- dataset = dataset.shuffle(num_elements, reshuffle_each_iteration=reshuffle)
+ dataset_1 = dataset_ops.Dataset.range(num_elements)
+ dataset_2 = dataset_1.shuffle(
+ num_elements, reshuffle_each_iteration=reshuffle)
- shuffle_1 = self.getDatasetOutput(dataset)
- dataset = self.graphRoundTrip(dataset, allow_stateful=True)
- shuffle_2 = self.getDatasetOutput(dataset)
+ shuffle_1 = self.getDatasetOutput(dataset_2)
+ dataset_3 = self.graphRoundTrip(dataset_2, allow_stateful=True)
+ shuffle_2 = self.getDatasetOutput(dataset_3)
self.assertCountEqual(shuffle_1, shuffle_2)
- self.assertNotEqual(shuffle_1, shuffle_2)
+ if reshuffle:
+ self.assertNotEqual(shuffle_1, shuffle_2)
@combinations.generate(test_base.eager_only_combinations())
def testCheckpointLargeShuffleBuffer(self):