[tf.data] Memory-safe implementation of sharing access to the memory cache.

PiperOrigin-RevId: 307736215
Change-Id: If10ef65e6706a106e6bb4fc2d6fe4542bbe056cc
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 5f7dedc..9a1a4ee 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -694,8 +694,10 @@
 class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
  public:
   explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
-                             MemoryCache* cache)
-      : DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) {
+                             std::shared_ptr<MemoryCache> cache)
+      : DatasetBase(DatasetContext(ctx)),
+        input_(input),
+        cache_(std::move(cache)) {
     input_->Ref();
   }
 
@@ -708,7 +710,7 @@
     return absl::make_unique<MemoryIterator>(
         MemoryIterator::Params{
             this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
-        cache_);
+        cache_.get());
   }
 
   const DataTypeVector& output_dtypes() const override {
@@ -964,7 +966,7 @@
   };  // MemoryIterator
 
   const DatasetBase* const input_;
-  MemoryCache* const cache_;
+  const std::shared_ptr<MemoryCache> cache_;
 };  // MemoryDatasetBase
 
 // This version of memory dataset has an exclusive ownership of the memory cache
@@ -973,22 +975,19 @@
 class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
  public:
   MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
-                MemoryCache* cache, const ResourceHandle& resource_handle)
-      : MemoryDatasetBase(ctx, input, cache),
-        resource_handle_(resource_handle) {
-    cleanup_ = [this, mgr = ctx->resource_manager()]() {
-      DCHECK(cache_->RefCountIsOne());
-      Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
-                                          resource_handle_.name());
-      if (!s.ok()) {
-        LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
-      }
-    };
-  }
+                MemoryCacheManager* manager, ResourceHandle&& resource_handle)
+      : MemoryDatasetBase(ctx, input, manager->get()),
+        manager_(manager),
+        resource_handle_(std::move(resource_handle)),
+        resource_mgr_(ctx->resource_manager()) {}
 
   ~MemoryDataset() override {
-    cache_->Unref();
-    cleanup_();
+    manager_->Unref();
+    Status s = resource_mgr_->Delete<MemoryCacheManager>(
+        resource_handle_.container(), resource_handle_.name());
+    if (!s.ok()) {
+      LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
+    }
   }
 
  protected:
@@ -1005,8 +1004,9 @@
   }
 
  private:
-  std::function<void()> cleanup_;
+  MemoryCacheManager* const manager_;  // Owned.
   const ResourceHandle resource_handle_;
+  ResourceMgr* const resource_mgr_;  // Not owned.
 };
 
 // This version of memory dataset has a shared ownership of the memory cache
@@ -1016,28 +1016,23 @@
     : public CacheDatasetOp::MemoryDatasetBase {
  public:
   MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
-                  MemoryCache* cache, const ResourceHandle& resource_handle)
-      : MemoryDatasetBase(ctx, input, cache),
-        resource_handle_(std::move(resource_handle)) {
-    cleanup_ = [this, mgr = ctx->resource_manager()]() {
-      if (cache_->RefCountIsOne()) {
-        Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
-                                            resource_handle_.name());
-        if (!s.ok()) {
-          if (errors::IsNotFound(s)) {
-            // This is a bening race resulting from concurrent deletion.
-            VLOG(1) << "Failed to delete cache resource: " << s.ToString();
-          } else {
-            LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
-          }
-        }
-      }
-    };
-  }
+                  MemoryCacheManager* manager, ResourceHandle&& resource_handle,
+                  bool owns_resource)
+      : MemoryDatasetBase(ctx, input, manager->get()),
+        manager_(manager),
+        owns_resource_(owns_resource),
+        resource_handle_(std::move(resource_handle)),
+        resource_mgr_(ctx->resource_manager()) {}
 
   ~MemoryDatasetV2() override {
-    cache_->Unref();
-    cleanup_();
+    manager_->Unref();
+    if (owns_resource_) {
+      Status s = resource_mgr_->Delete<MemoryCacheManager>(
+          resource_handle_.container(), resource_handle_.name());
+      if (!s.ok()) {
+        LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
+      }
+    }
   }
 
  protected:
@@ -1058,8 +1053,10 @@
   }
 
  private:
-  std::function<void()> cleanup_;
+  MemoryCacheManager* const manager_;  // Owned.
+  const bool owns_resource_;
   const ResourceHandle resource_handle_;
+  ResourceMgr* const resource_mgr_;  // Not owned.
 };
 
 CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
@@ -1077,33 +1074,39 @@
     auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
                                 resource_id_counter.fetch_add(1));
     if (op_version_ == 2) {
-      MemoryCache* cache = nullptr;
+      bool owns_resource = false;
+      MemoryCacheManager* manager = nullptr;
       auto handle = HandleFromInput(ctx, 2);
-      Status s = ctx->resource_manager()->Lookup<MemoryCache>(
-          handle.container(), handle.name(), &cache);
+      Status s = ctx->resource_manager()->Lookup<MemoryCacheManager>(
+          handle.container(), handle.name(), &manager);
       if (errors::IsNotFound(s)) {
-        OP_REQUIRES_OK(ctx,
-                       ctx->resource_manager()->LookupOrCreate<MemoryCache>(
-                           container, name, &cache, [](MemoryCache** cache) {
-                             *cache = new MemoryCache();
-                             return Status::OK();
-                           }));
-        handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
+        owns_resource = true;
+        OP_REQUIRES_OK(
+            ctx,
+            ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
+                container, name, &manager, [](MemoryCacheManager** manager) {
+                  *manager = new MemoryCacheManager();
+                  return Status::OK();
+                }));
+        handle = MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
       } else {
         OP_REQUIRES_OK(ctx, s);
       }
-      // Ownership of cache is transferred onto `MemoryDatasetV2`.
-      *output = new MemoryDatasetV2(ctx, input, cache, std::move(handle));
+      // Ownership of manager is transferred onto `MemoryDatasetV2`.
+      *output = new MemoryDatasetV2(ctx, input, manager, std::move(handle),
+                                    owns_resource);
     } else {
-      MemoryCache* cache;
-      OP_REQUIRES_OK(ctx, ctx->resource_manager()->LookupOrCreate<MemoryCache>(
-                              container, name, &cache, [](MemoryCache** cache) {
-                                *cache = new MemoryCache();
-                                return Status::OK();
-                              }));
-      auto handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
-      // Ownership of cache is transferred onto `MemoryDataset`.
-      *output = new MemoryDataset(ctx, input, cache, handle);
+      MemoryCacheManager* manager;
+      OP_REQUIRES_OK(
+          ctx, ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
+                   container, name, &manager, [](MemoryCacheManager** manager) {
+                     *manager = new MemoryCacheManager();
+                     return Status::OK();
+                   }));
+      auto handle =
+          MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
+      // Ownership of manager is transferred onto `MemoryDataset`.
+      *output = new MemoryDataset(ctx, input, manager, std::move(handle));
     }
   } else {
     if (op_version_ == 2) {
diff --git a/tensorflow/core/kernels/data/cache_ops.cc b/tensorflow/core/kernels/data/cache_ops.cc
index 8b58e7b..90c2e90 100644
--- a/tensorflow/core/kernels/data/cache_ops.cc
+++ b/tensorflow/core/kernels/data/cache_ops.cc
@@ -31,7 +31,7 @@
 
 }  // namespace
 
-string MemoryCache::DebugString() const { return kMemoryCache; }
+string MemoryCacheManager::DebugString() const { return kMemoryCache; }
 
 void MemoryCache::Complete(std::vector<std::vector<Tensor>>&& cache) {
   mutex_lock l(mu_);
@@ -65,19 +65,15 @@
 
 AnonymousMemoryCacheHandleOp::AnonymousMemoryCacheHandleOp(
     OpKernelConstruction* ctx)
-    : AnonymousResourceOp<MemoryCache>(ctx) {}
-
-void AnonymousMemoryCacheHandleOp::Compute(OpKernelContext* ctx) {
-  AnonymousResourceOp<MemoryCache>::Compute(ctx);
-}
+    : AnonymousResourceOp<MemoryCacheManager>(ctx) {}
 
 string AnonymousMemoryCacheHandleOp::name() { return kMemoryCache; }
 
 Status AnonymousMemoryCacheHandleOp::CreateResource(
     OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
     std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
-    FunctionLibraryRuntime* lib, MemoryCache** resource) {
-  *resource = new MemoryCache();
+    FunctionLibraryRuntime* lib, MemoryCacheManager** manager) {
+  *manager = new MemoryCacheManager();
   return Status::OK();
 }
 
diff --git a/tensorflow/core/kernels/data/cache_ops.h b/tensorflow/core/kernels/data/cache_ops.h
index d21679b..c670d6f 100644
--- a/tensorflow/core/kernels/data/cache_ops.h
+++ b/tensorflow/core/kernels/data/cache_ops.h
@@ -27,12 +27,10 @@
 // The expected use is that a single `MemoryWriterIterator` populates the
 // cache with dataset elements. Once all elements are cached, the cache can
 // be used by one or more `MemoryReaderIterator`s.
-class MemoryCache : public ResourceBase {
+class MemoryCache {
  public:
   MemoryCache() = default;
 
-  string DebugString() const override;
-
   // Marks the cache as completed.
   void Complete(std::vector<std::vector<Tensor>>&& cache);
 
@@ -55,11 +53,24 @@
   std::vector<std::vector<Tensor>> cache_ TF_GUARDED_BY(mu_);
 };
 
+// A resource wrapping a shared instance of a memory cache.
+class MemoryCacheManager : public ResourceBase {
+ public:
+  MemoryCacheManager() : cache_(std::make_shared<MemoryCache>()) {}
+
+  string DebugString() const override;
+
+  std::shared_ptr<MemoryCache> get() { return cache_; }
+
+ private:
+  std::shared_ptr<MemoryCache> cache_;
+};
+
 // Creates an instance of cache resource and transfers ownership to the caller.
-class AnonymousMemoryCacheHandleOp : public AnonymousResourceOp<MemoryCache> {
+class AnonymousMemoryCacheHandleOp
+    : public AnonymousResourceOp<MemoryCacheManager> {
  public:
   explicit AnonymousMemoryCacheHandleOp(OpKernelConstruction* ctx);
-  void Compute(OpKernelContext* ctx) override;
 
  private:
   string name() override;
@@ -67,7 +78,7 @@
                         std::unique_ptr<FunctionLibraryDefinition> flib_def,
                         std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                         FunctionLibraryRuntime* lib,
-                        MemoryCache** resource) override;
+                        MemoryCacheManager** manager) override;
 };
 
 // Deletes an instance of cache resource.