[tf.data] Use the cached finalized dataset when finalizing during sequential access.

PiperOrigin-RevId: 426551686
Change-Id: Iac55dd07e9a2c428cd6028141a6b2ed013c2d935
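
DatasetBase::Finalize is now const and caches the finalized dataset on
the input dataset, so callers that finalize the same dataset repeatedly
(e.g. once per sequential access) reuse the cached instance instead of
rebuilding it. A minimal sketch of the caching logic, simplified from
the tensorflow/core/framework/dataset.cc change below (the trailing
return is an assumption based on the StatusOr<DatasetBase*> signature):

    StatusOr<DatasetBase*> DatasetBase::Finalize(
        OpKernelContext* ctx,
        std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
            make_finalized_dataset) const {
      mutex_lock l(mu_);  // mu_ is now mutable so Finalize can be const.
      if (!finalized_dataset_) {
        // First call: build and cache the finalized dataset.
        TF_ASSIGN_OR_RETURN(finalized_dataset_, make_finalized_dataset());
      }
      // Later calls return the cached, unowned pointer.
      return finalized_dataset_.get();
    }
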
diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD
index 432651c..3c460e0 100644
--- a/tensorflow/core/data/BUILD
+++ b/tensorflow/core/data/BUILD
@@ -24,6 +24,8 @@
     "captured_function.h",
     "dataset_utils.cc",
     "dataset_utils.h",
+    "finalization_utils.cc",
+    "finalization_utils.h",
     "name_utils.cc",
     "name_utils.h",
     "rewrite_utils.cc",
diff --git a/tensorflow/core/data/finalization_utils.cc b/tensorflow/core/data/finalization_utils.cc
index 359e139..64296bf 100644
--- a/tensorflow/core/data/finalization_utils.cc
+++ b/tensorflow/core/data/finalization_utils.cc
@@ -23,7 +23,7 @@
 namespace data {
 
 StatusOr<DatasetBase*> GetFinalizedDataset(OpKernelContext* ctx,
-                                           DatasetBase* dataset) {
+                                           const DatasetBase* dataset) {
   return dataset->Finalize(
       ctx, [ctx, dataset]() -> StatusOr<core::RefCountPtr<DatasetBase>> {
         core::RefCountPtr<DatasetBase> dataset_ref_ptr;
diff --git a/tensorflow/core/data/finalization_utils.h b/tensorflow/core/data/finalization_utils.h
index 9edfa0a..f019c20 100644
--- a/tensorflow/core/data/finalization_utils.h
+++ b/tensorflow/core/data/finalization_utils.h
@@ -24,9 +24,10 @@
 namespace tensorflow {
 namespace data {
 
-// Returns the finalized version of the dataset.
+// Returns the finalized version of the dataset. The returned DatasetBase is
+// unowned and lives for as long as the input `dataset`.
 StatusOr<DatasetBase*> GetFinalizedDataset(OpKernelContext* ctx,
-                                           DatasetBase* dataset);
+                                           const DatasetBase* dataset);
 
 }  // namespace data
 }  // namespace tensorflow
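
Note: callers consume the result as an unowned pointer. Roughly, the new
call-site pattern, mirroring the iterator_ops.cc change further down in
this diff (ctx, dataset, params, and iterator are as in that call site):

    DatasetBase* finalized_dataset;
    TF_ASSIGN_OR_RETURN(finalized_dataset, GetFinalizedDataset(ctx, dataset));
    // No core::ScopedUnref: the pointer is unowned and cached on `dataset`,
    // which must outlive its use.
    TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
        IteratorContext(std::move(params)), /*parent=*/nullptr, "Iterator",
        &iterator));
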
diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc
index 4e4240b..79c93f3 100644
--- a/tensorflow/core/data/root_dataset.cc
+++ b/tensorflow/core/data/root_dataset.cc
@@ -15,6 +15,10 @@
 
 #include "tensorflow/core/data/root_dataset.h"
 
+#include <functional>
+#include <string>
+#include <utility>
+
 #include "tensorflow/core/data/dataset_utils.h"
 #include "tensorflow/core/data/name_utils.h"
 #include "tensorflow/core/data/rewrite_utils.h"
@@ -45,39 +49,84 @@
   return x == y ? z : x;
 }
 
+void SetRootDatasetParams(const Options& options, RootDataset::Params* params) {
+  if (ShouldConfigureMaxIntraOpParallelism(options)) {
+    params->max_intra_op_parallelism =
+        options.threading_options().max_intra_op_parallelism();
+  }
+  if (ShouldUsePrivateThreadPool(options)) {
+    params->private_threadpool_size =
+        options.threading_options().private_threadpool_size();
+  }
+  params->autotune = ShouldUseAutotuning(options);
+  if (params->autotune) {
+    params->autotune_algorithm =
+        options.autotune_options().optional_autotune_algorithm_case() ==
+                AutotuneOptions::kAutotuneAlgorithm
+            ? options.autotune_options().autotune_algorithm()
+            : model::AutotuneAlgorithm::DEFAULT;
+    params->autotune_cpu_budget = value_or_default(
+        options.autotune_options().cpu_budget(), 0, GetCpuBudget());
+    params->autotune_ram_budget =
+        value_or_default(options.autotune_options().ram_budget(), 0,
+                         model::kRamBudgetShare * port::AvailableRam());
+  }
+}
+
+void AddTraceMetadata(const RootDataset::Params& params,
+                      TraceMeMetadata* trace_metadata) {
+  if (params.autotune) {
+    trace_metadata->push_back(std::make_pair(
+        kAlgorithm, model::AutotuneAlgorithm_Name(params.autotune_algorithm)));
+    trace_metadata->push_back(std::make_pair(
+        kCpuBudget, strings::Printf("%lld", static_cast<long long>(
+                                                params.autotune_cpu_budget))));
+    trace_metadata->push_back(std::make_pair(
+        kRamBudget,
+        strings::Printf("%lld", static_cast<long long>(
+                                    params.autotune_ram_budget / 1.0e6))));
+  }
+  if (params.max_intra_op_parallelism >= 0) {
+    trace_metadata->push_back(std::make_pair(
+        kIntraOpParallelism,
+        strings::Printf("%lld", static_cast<long long>(value_or_default(
+                                    params.max_intra_op_parallelism, 0,
+                                    port::MaxParallelism())))));
+  }
+  if (params.private_threadpool_size >= 0) {
+    trace_metadata->push_back(std::make_pair(
+        kPrivateThreadpoolSize,
+        strings::Printf("%lld", static_cast<long long>(value_or_default(
+                                    params.private_threadpool_size, 0,
+                                    port::MaxParallelism())))));
+  }
+  auto experiments = GetExperiments();
+  if (!experiments.empty()) {
+    trace_metadata->push_back(
+        std::make_pair(kExperiments, absl::StrJoin(experiments, " ")));
+  }
+}
 }  // namespace
 
 // static
 Status RootDataset::FromOptions(const DatasetBase* input,
                                 DatasetBase** output) {
-  const Options& options = input->options();
   Params params;
-  if (ShouldConfigureMaxIntraOpParallelism(options)) {
-    params.max_intra_op_parallelism =
-        options.threading_options().max_intra_op_parallelism();
-  }
-  if (ShouldUsePrivateThreadPool(options)) {
-    params.private_threadpool_size =
-        options.threading_options().private_threadpool_size();
-  }
-  params.autotune = ShouldUseAutotuning(options);
-  if (params.autotune) {
-    params.autotune_algorithm =
-        options.autotune_options().optional_autotune_algorithm_case() ==
-                AutotuneOptions::kAutotuneAlgorithm
-            ? options.autotune_options().autotune_algorithm()
-            : model::AutotuneAlgorithm::DEFAULT;
-    params.autotune_cpu_budget = value_or_default(
-        options.autotune_options().cpu_budget(), 0, GetCpuBudget());
-    params.autotune_ram_budget =
-        value_or_default(options.autotune_options().ram_budget(), 0,
-                         model::kRamBudgetShare * port::AvailableRam());
-  }
+  SetRootDatasetParams(input->options(), &params);
   *output = new RootDataset(input, params);
   (*output)->Initialize(/*metadata=*/{});
   return Status::OK();
 }
 
+Status RootDataset::FromOptions(core::RefCountPtr<DatasetBase> input,
+                                DatasetBase** output) {
+  Params params;
+  SetRootDatasetParams(input->options(), &params);
+  *output = new RootDataset(std::move(input), params);
+  (*output)->Initialize(/*metadata=*/{});
+  return Status::OK();
+}
+
 class RootDataset::Iterator : public DatasetIterator<RootDataset> {
  public:
   explicit Iterator(const Params& params)
@@ -215,45 +264,25 @@
   std::unique_ptr<IteratorBase> input_impl_;
 };
 
-RootDataset::RootDataset(const DatasetBase* input, Params params)
+RootDataset::RootDataset(const DatasetBase* input, const Params& params)
     : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType),
                                   name_utils::OpName(kDatasetType)})),
       input_(input),
       params_(std::move(params)) {
-  if (params_.autotune) {
-    traceme_metadata_.push_back(std::make_pair(
-        kAlgorithm, model::AutotuneAlgorithm_Name(params_.autotune_algorithm)));
-    traceme_metadata_.push_back(std::make_pair(
-        kCpuBudget, strings::Printf("%lld", static_cast<long long>(
-                                                params_.autotune_cpu_budget))));
-    traceme_metadata_.push_back(std::make_pair(
-        kRamBudget,
-        strings::Printf("%lld", static_cast<long long>(
-                                    params_.autotune_ram_budget / 1.0e6))));
-  }
-  if (params_.max_intra_op_parallelism >= 0) {
-    traceme_metadata_.push_back(std::make_pair(
-        kIntraOpParallelism,
-        strings::Printf("%lld", static_cast<long long>(value_or_default(
-                                    params_.max_intra_op_parallelism, 0,
-                                    port::MaxParallelism())))));
-  }
-  if (params_.private_threadpool_size >= 0) {
-    traceme_metadata_.push_back(std::make_pair(
-        kPrivateThreadpoolSize,
-        strings::Printf("%lld", static_cast<long long>(value_or_default(
-                                    params_.private_threadpool_size, 0,
-                                    port::MaxParallelism())))));
-  }
-  auto experiments = GetExperiments();
-  if (!experiments.empty()) {
-    traceme_metadata_.push_back(
-        std::make_pair(kExperiments, absl::StrJoin(experiments, " ")));
-  }
-  input_->Ref();
+  AddTraceMetadata(params_, &traceme_metadata_);
 }
 
-RootDataset::~RootDataset() { input_->Unref(); }
+RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
+                         const Params& params)
+    : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType),
+                                  name_utils::OpName(kDatasetType)})),
+      params_(std::move(params)) {
+  owned_input_ = std::move(input);
+  input_ = owned_input_.get();
+  AddTraceMetadata(params_, &traceme_metadata_);
+}
+
+RootDataset::~RootDataset() {}
 
 std::unique_ptr<IteratorBase> RootDataset::MakeIteratorInternal(
     const string& prefix) const {
@@ -337,20 +366,25 @@
   };
   Status s = RewriteDataset(ctx, input, std::move(config_factory),
                             /*record_fingerprint=*/true, output);
+  bool rewritten = (*output != input);
   if (errors::IsDeadlineExceeded(s)) {
     // Ignore DeadlineExceeded as it implies that the attempted rewrite took too
     // long which should not prevent further computation.
     LOG(WARNING) << s.ToString();
-    return RootDataset::FromOptions(input, output);
-  }
-  if (!s.ok()) {
+  } else if (!s.ok()) {
     return s;
   }
-  input = *output;
-  TF_RETURN_IF_ERROR(RootDataset::FromOptions(input, output));
-  input->Unref();
+  if (!rewritten) {
+    return RootDataset::FromOptions(input, output);
+  } else {
+    input = *output;
+    core::RefCountPtr<DatasetBase> dataset_ref_ptr(
+        const_cast<DatasetBase*>(input));
+    return RootDataset::FromOptions(std::move(dataset_ref_ptr), output);
+  }
   return Status::OK();
 }
+
 #else   // !IS_MOBILE_PLATFORM
 Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input,
                        DatasetBase** output) {
diff --git a/tensorflow/core/data/root_dataset.h b/tensorflow/core/data/root_dataset.h
index 11e28e4..70c9708 100644
--- a/tensorflow/core/data/root_dataset.h
+++ b/tensorflow/core/data/root_dataset.h
@@ -18,6 +18,7 @@
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/model.h"
 #include "tensorflow/core/framework/model.pb.h"
+#include "tensorflow/core/platform/refcount.h"
 
 namespace tensorflow {
 namespace data {
@@ -36,6 +37,8 @@
   };
 
   static Status FromOptions(const DatasetBase* input, DatasetBase** output);
+  static Status FromOptions(core::RefCountPtr<DatasetBase> input,
+                            DatasetBase** output);
 
   ~RootDataset() override;
 
@@ -60,9 +63,12 @@
  private:
   class Iterator;
 
-  RootDataset(const DatasetBase* input, Params params);
+  RootDataset(const DatasetBase* input, const Params& params);
+
+  RootDataset(core::RefCountPtr<DatasetBase> input, const Params& params);
 
   const DatasetBase* input_;
+  core::RefCountPtr<DatasetBase> owned_input_;
   const Params params_;
   TraceMeMetadata traceme_metadata_;
 };
@@ -70,7 +76,8 @@
 // Finalizes the `input` dataset, which is expected to be called before the
 // dataset is about to be iterated. This can for instance apply static graph
 // optimizations or inject internal tf.data transformations responsible for
-// autotuning or threading configuration.
+// autotuning or threading configuration. The caller must ensure that the
+// input dataset to be finalized outlives the output.
 Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input,
                        DatasetBase** output);
 
diff --git a/tensorflow/core/data/standalone.cc b/tensorflow/core/data/standalone.cc
index 97ee2a0..8e0abe8 100644
--- a/tensorflow/core/data/standalone.cc
+++ b/tensorflow/core/data/standalone.cc
@@ -116,9 +116,9 @@
   OpKernelContext ctx(&op_params, /*num_outputs=*/0);
   TF_RETURN_IF_ERROR(data::FinalizeDataset(&ctx, dataset, &finalized_dataset));
   core::ScopedUnref unref(finalized_dataset);
-  *result = WrapUnique(new Dataset(finalized_dataset, device_mgr.release(),
-                                   pflr.release(), flib_def.release(),
-                                   pool.release(), std::move(runner)));
+  *result = WrapUnique(new Dataset(
+      finalized_dataset, dataset, device_mgr.release(), pflr.release(),
+      flib_def.release(), pool.release(), std::move(runner)));
   return Status::OK();
 }  // static
 
@@ -146,8 +146,8 @@
 
   // Create the iterator from the dataset.
   std::unique_ptr<IteratorBase> iterator;
-  TF_RETURN_IF_ERROR(dataset_->MakeIterator(ctx.get(), /*parent=*/nullptr,
-                                            "Iterator", &iterator));
+  TF_RETURN_IF_ERROR(finalized_dataset_->MakeIterator(
+      ctx.get(), /*parent=*/nullptr, "Iterator", &iterator));
   *result = WrapUnique(new Iterator(iterator.release(), ctx.release()));
 
   return Status::OK();
@@ -159,28 +159,33 @@
 
 Status Dataset::MakeSplitProviders(
     std::vector<std::unique_ptr<SplitProvider>>* result) {
-  return dataset_->MakeSplitProviders(result);
+  return finalized_dataset_->MakeSplitProviders(result);
 }
 
-const DatasetBase* Dataset::Get() const { return dataset_; }
+const DatasetBase* Dataset::Get() const { return finalized_dataset_; }
 
-Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
-                 ProcessFunctionLibraryRuntime* pflr,
+Dataset::Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset,
+                 DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr,
                  FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
                  std::function<void(std::function<void()>)> runner)
-    : dataset_(dataset),
+    : finalized_dataset_(finalized_dataset),
+      original_dataset_(original_dataset),
       device_mgr_(device_mgr),
       flib_def_(flib_def),
       pflr_(pflr),
       interop_threadpool_(pool),
       runner_(std::move(runner)),
       unbounded_thread_pool_(Env::Default(), "tf_data_standalone") {
-  dataset_->Ref();
+  finalized_dataset_->Ref();
+  original_dataset_->Ref();
   function_handle_cache_ =
       absl::make_unique<FunctionHandleCache>(pflr_->GetFLR("/device:CPU:0"));
 }
 
-Dataset::~Dataset() { dataset_->Unref(); }
+Dataset::~Dataset() {
+  finalized_dataset_->Unref();
+  original_dataset_->Unref();
+}
 
 }  // namespace standalone
 }  // namespace data
diff --git a/tensorflow/core/data/standalone.h b/tensorflow/core/data/standalone.h
index c641cac..c35e6a5 100644
--- a/tensorflow/core/data/standalone.h
+++ b/tensorflow/core/data/standalone.h
@@ -112,12 +112,13 @@
   const DatasetBase* Get() const;
 
  private:
-  Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
-          ProcessFunctionLibraryRuntime* pflr,
+  Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset,
+          DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr,
           FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
           std::function<void(std::function<void()>)> runner);
 
-  DatasetBase* dataset_;  // owned
+  DatasetBase* finalized_dataset_;  // owned
+  DatasetBase* original_dataset_;   // owned
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 79c89a6..3533e4a 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -646,7 +646,7 @@
 StatusOr<DatasetBase*> DatasetBase::Finalize(
     OpKernelContext* ctx,
     std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
-        make_finalized_dataset) {
+        make_finalized_dataset) const {
   mutex_lock l(mu_);
   if (!finalized_dataset_) {
     TF_ASSIGN_OR_RETURN(finalized_dataset_, make_finalized_dataset());
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 36db239..120d206 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -1032,7 +1032,7 @@
   virtual StatusOr<DatasetBase*> Finalize(
       OpKernelContext* ctx,
       std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
-          make_finalized_dataset);
+          make_finalized_dataset) const;
 
   // Wrapper around a GraphDefBuilder which provides support for serializing
   // Datasets as GraphDefs.
@@ -1099,9 +1099,9 @@
   const string node_name_;
   Metadata metadata_;
   Options options_;
-  mutex mu_;
+  mutable mutex mu_;
   mutable mutex cardinality_mu_;
-  core::RefCountPtr<DatasetBase> finalized_dataset_;
+  mutable core::RefCountPtr<DatasetBase> finalized_dataset_;
   //  The number of source datasets feeding into the dataset. A source dataset
   //  is a leaf in the subtree of dataset inputs.
   int64_t num_sources_ = -1;
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 99bbe51..bcafc82 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -58,6 +58,7 @@
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/data:dataset_utils",
         "//tensorflow/core/data:name_utils",
         "//tensorflow/core/data:serialization_utils",
         "//tensorflow/core/framework:dataset_options_proto_cc",
@@ -419,6 +420,7 @@
         "//tensorflow/core:session_options",
         "//tensorflow/core/data:captured_function",
         "//tensorflow/core/data:dataset_utils",
+        "//tensorflow/core/data:finalization_utils",
         "//tensorflow/core/data:root_dataset",
         "//tensorflow/core/data:serialization_utils",
         "//tensorflow/core/data:unbounded_thread_pool",
@@ -530,6 +532,7 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core/data:dataset_utils",
+        "//tensorflow/core/data:finalization_utils",
         "//tensorflow/core/data:root_dataset",
         "//tensorflow/core/data:unbounded_thread_pool",
         "//tensorflow/core/kernels:ops_util",
@@ -1367,6 +1370,7 @@
     srcs = [
         "//tensorflow/core/data:captured_function.h",
         "//tensorflow/core/data:dataset_utils.h",
+        "//tensorflow/core/data:finalization_utils.h",
         "//tensorflow/core/data:utils.h",
         "//tensorflow/core/data:name_utils.h",
         "//tensorflow/core/data:rewrite_utils.h",
@@ -1391,6 +1395,7 @@
         ":portable_all_op_kernels_headers",
         "//tensorflow/core/data:captured_function.cc",
         "//tensorflow/core/data:dataset_utils.cc",
+        "//tensorflow/core/data:finalization_utils.cc",
         "//tensorflow/core/data:name_utils.cc",
         "//tensorflow/core/data:utils.cc",
         "//tensorflow/core/data:rewrite_utils.cc",
diff --git a/tensorflow/core/kernels/data/experimental/random_access_ops.h b/tensorflow/core/kernels/data/experimental/random_access_ops.h
index 3450ff5..bcc54ba 100644
--- a/tensorflow/core/kernels/data/experimental/random_access_ops.h
+++ b/tensorflow/core/kernels/data/experimental/random_access_ops.h
@@ -34,6 +34,8 @@
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
   }
 
+  ~GetElementAtIndexOp() override {}
+
  protected:
   Status DoCompute(OpKernelContext* ctx) override;
 
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 5511d9f..5bb26d7 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -27,6 +27,7 @@
 #include "tensorflow/core/common_runtime/threadpool_device.h"
 #include "tensorflow/core/data/captured_function.h"
 #include "tensorflow/core/data/dataset_utils.h"
+#include "tensorflow/core/data/finalization_utils.h"
 #include "tensorflow/core/data/root_dataset.h"
 #include "tensorflow/core/data/serialization_utils.h"
 #include "tensorflow/core/data/utils.h"
@@ -194,6 +195,7 @@
                                  IteratorStateReader* reader) {
   const DatasetBase* dataset;
   std::shared_ptr<State> new_state;
+  const DatasetBase* input_dataset;
   {
     tf_shared_lock l(mu_);
     if (!iterator_state_->iterator()) {
@@ -211,6 +213,7 @@
         std::make_shared<State>(iterator_state_->flib_def(),
                                 iterator_state_->pflr(), iterator_state_->flr(),
                                 /*iterator=*/nullptr);
+    input_dataset = iterator_state_->dataset();
   }
   core::ScopedUnref scoped_unref(dataset);
   IteratorContext::Params params(ctx);
@@ -229,7 +232,8 @@
   std::unique_ptr<IteratorBase> iterator_base;
   TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
       IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
-  new_state->DowncastAndSetIterator(std::move(iterator_base));
+  new_state->DowncastAndSetIteratorAndDataset(std::move(iterator_base),
+                                              input_dataset);
 
   mutex_lock l(mu_);
   std::swap(iterator_state_, new_state);
@@ -265,8 +269,7 @@
   std::unique_ptr<IteratorBase> iterator;
   if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
     DatasetBase* finalized_dataset;
-    TF_RETURN_IF_ERROR(FinalizeDataset(ctx, dataset, &finalized_dataset));
-    core::ScopedUnref unref(finalized_dataset);
+    TF_ASSIGN_OR_RETURN(finalized_dataset, GetFinalizedDataset(ctx, dataset));
     TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
         IteratorContext(std::move(params)),
         /*parent=*/nullptr, "Iterator", &iterator));
@@ -280,7 +283,7 @@
   TF_RETURN_IF_ERROR(
       VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
 
-  new_state->DowncastAndSetIterator(std::move(iterator));
+  new_state->DowncastAndSetIteratorAndDataset(std::move(iterator), dataset);
 
   mutex_lock l(mu_);
   std::swap(iterator_state_, new_state);
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index 6fc9339..5c3188e 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
 #define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
 
+#include <utility>
+
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/data/dataset_utils.h"
 #include "tensorflow/core/data/unbounded_thread_pool.h"
@@ -26,6 +28,7 @@
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/refcount.h"
 
 namespace tensorflow {
 namespace data {
@@ -89,9 +92,14 @@
     ~State() { cancellation_manager_.StartCancel(); }
 
     // Downcasts the given `IteratorBase` to a `DatasetBaseIterator`, and uses
-    // it to set the `iterator` field.
-    void DowncastAndSetIterator(std::unique_ptr<IteratorBase> it) {
+    // it to set the `iterator` and the `dataset` fields.
+    void DowncastAndSetIteratorAndDataset(std::unique_ptr<IteratorBase> it,
+                                          const DatasetBase* dataset) {
       iterator_.reset(static_cast<DatasetBaseIterator*>(it.release()));
+      if (dataset) {
+        dataset->Ref();
+        dataset_.reset(const_cast<DatasetBase*>(dataset));
+      }
     }
 
     std::shared_ptr<FunctionLibraryDefinition> flib_def() { return flib_def_; }
@@ -112,6 +120,8 @@
 
     DatasetBaseIterator* iterator() { return iterator_.get(); }
 
+    DatasetBase* dataset() { return dataset_.get(); }
+
    private:
     std::shared_ptr<FunctionLibraryDefinition> flib_def_;
     FunctionLibraryRuntime* flr_ = nullptr;  // not owned
@@ -120,6 +130,7 @@
     ResourceMgr resource_mgr_;
     CancellationManager cancellation_manager_;
     std::unique_ptr<DatasetBaseIterator> iterator_;
+    core::RefCountPtr<DatasetBase> dataset_;
   };
 
   UnboundedThreadPool unbounded_thread_pool_;
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
index 7f3fdc1..3151e04 100644
--- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -92,7 +92,7 @@
   }
 
   Status Init(std::unique_ptr<IteratorBase> iterator, int64_t max_buffer_size,
-              int64_t* incarnation_id) {
+              int64_t* incarnation_id, DatasetBase* dataset) {
     if (iterator) {
       TF_RETURN_IF_ERROR(
           VerifyTypesMatch(output_types_, iterator->output_dtypes()));
@@ -104,6 +104,8 @@
     if (multi_device_buffer_) {
       multi_device_buffer_->Reset();
     }
+    dataset->Ref();
+    dataset_.reset(dataset);
 
     ++incarnation_id_;
     *incarnation_id = incarnation_id_;
@@ -450,6 +452,7 @@
 
   int64_t incarnation_id_ TF_GUARDED_BY(mu_) = 0;
   std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ TF_GUARDED_BY(mu_);
+  core::RefCountPtr<DatasetBase> dataset_;
 };
 
 // Used to generate unique names for anonymous multi device iterators.
@@ -654,7 +657,7 @@
     core::ScopedUnref unref(finalized_dataset);
     int64_t incarnation_id;
     OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
-                                       &incarnation_id));
+                                       &incarnation_id, dataset));
     Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
     tensor_incarnation_id.scalar<int64_t>()() = incarnation_id;
     OP_REQUIRES_OK(ctx,
@@ -825,9 +828,9 @@
 
   void Compute(OpKernelContext* ctx) override {
     ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
-    // The iterator resource is guaranteed to exist because the variant tensor
-    // wrapping the deleter is provided as an unused input to this op, which
-    // guarantees that it has not run yet.
+    // The iterator resource is guaranteed to
+    // exist because the variant tensor wrapping the deleter is provided as an
+    // unused input to this op, which guarantees that it has not run yet.
     OP_REQUIRES_OK(ctx, DeleteResource(ctx, handle));
   }
 };
diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py
index 6f442ea..3669dfa 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_test.py
@@ -377,23 +377,27 @@
       self.evaluate(get_next())
 
   @combinations.generate(
-      combinations.times(
-          test_base.default_test_combinations(),
-          combinations.combine(reshuffle=[True, False])))
-  def testRerandomizeOnReplicate(self, reshuffle):
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(reshuffle=[True, False])))
+  def testDontRerandomizeOnReplicate(self, reshuffle):
     random_seed.set_random_seed(None)
-    # When no seeds are fixed, each instantiation of the shuffle dataset should
-    # produce elements in a different order.
+    # The seed generator configuration is preserved across serialization of
+    # the dataset, so each instantiation of the shuffle dataset should
+    # produce the same shuffle order when reshuffle=False. For that to hold,
+    # the original dataset must be kept alive: if it were destroyed, its
+    # seeds would be destroyed with it.
     num_elements = 100
-    dataset = dataset_ops.Dataset.range(num_elements)
-    dataset = dataset.shuffle(num_elements, reshuffle_each_iteration=reshuffle)
+    dataset_1 = dataset_ops.Dataset.range(num_elements)
+    dataset_2 = dataset_1.shuffle(
+        num_elements, reshuffle_each_iteration=reshuffle)
 
-    shuffle_1 = self.getDatasetOutput(dataset)
-    dataset = self.graphRoundTrip(dataset, allow_stateful=True)
-    shuffle_2 = self.getDatasetOutput(dataset)
+    shuffle_1 = self.getDatasetOutput(dataset_2)
+    dataset_3 = self.graphRoundTrip(dataset_2, allow_stateful=True)
+    shuffle_2 = self.getDatasetOutput(dataset_3)
 
     self.assertCountEqual(shuffle_1, shuffle_2)
-    self.assertNotEqual(shuffle_1, shuffle_2)
+    if reshuffle:
+      self.assertNotEqual(shuffle_1, shuffle_2)
 
   @combinations.generate(test_base.eager_only_combinations())
   def testCheckpointLargeShuffleBuffer(self):