[tf.data] Memory-safe implementation of sharing access to the memory cache.
PiperOrigin-RevId: 307736215
Change-Id: If10ef65e6706a106e6bb4fc2d6fe4542bbe056cc
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 5f7dedc..9a1a4ee 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -694,8 +694,10 @@
class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
public:
explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
- MemoryCache* cache)
- : DatasetBase(DatasetContext(ctx)), input_(input), cache_(cache) {
+ std::shared_ptr<MemoryCache> cache)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ cache_(std::move(cache)) {
input_->Ref();
}
@@ -708,7 +710,7 @@
return absl::make_unique<MemoryIterator>(
MemoryIterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
- cache_);
+ cache_.get());
}
const DataTypeVector& output_dtypes() const override {
@@ -964,7 +966,7 @@
}; // MemoryIterator
const DatasetBase* const input_;
- MemoryCache* const cache_;
+ const std::shared_ptr<MemoryCache> cache_;
}; // MemoryDatasetBase
// This version of memory dataset has an exclusive ownership of the memory cache
@@ -973,22 +975,19 @@
class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
public:
MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
- MemoryCache* cache, const ResourceHandle& resource_handle)
- : MemoryDatasetBase(ctx, input, cache),
- resource_handle_(resource_handle) {
- cleanup_ = [this, mgr = ctx->resource_manager()]() {
- DCHECK(cache_->RefCountIsOne());
- Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
- resource_handle_.name());
- if (!s.ok()) {
- LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
- }
- };
- }
+ MemoryCacheManager* manager, ResourceHandle&& resource_handle)
+ : MemoryDatasetBase(ctx, input, manager->get()),
+ manager_(manager),
+ resource_handle_(std::move(resource_handle)),
+ resource_mgr_(ctx->resource_manager()) {}
~MemoryDataset() override {
- cache_->Unref();
- cleanup_();
+ manager_->Unref();
+ Status s = resource_mgr_->Delete<MemoryCacheManager>(
+ resource_handle_.container(), resource_handle_.name());
+ if (!s.ok()) {
+ LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
+ }
}
protected:
@@ -1005,8 +1004,9 @@
}
private:
- std::function<void()> cleanup_;
+ MemoryCacheManager* const manager_; // Owned.
const ResourceHandle resource_handle_;
+ ResourceMgr* const resource_mgr_; // Not owned.
};
// This version of memory dataset has a shared ownership of the memory cache
@@ -1016,28 +1016,23 @@
: public CacheDatasetOp::MemoryDatasetBase {
public:
MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
- MemoryCache* cache, const ResourceHandle& resource_handle)
- : MemoryDatasetBase(ctx, input, cache),
- resource_handle_(std::move(resource_handle)) {
- cleanup_ = [this, mgr = ctx->resource_manager()]() {
- if (cache_->RefCountIsOne()) {
- Status s = mgr->Delete<MemoryCache>(resource_handle_.container(),
- resource_handle_.name());
- if (!s.ok()) {
- if (errors::IsNotFound(s)) {
- // This is a bening race resulting from concurrent deletion.
- VLOG(1) << "Failed to delete cache resource: " << s.ToString();
- } else {
- LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
- }
- }
- }
- };
- }
+ MemoryCacheManager* manager, ResourceHandle&& resource_handle,
+ bool owns_resource)
+ : MemoryDatasetBase(ctx, input, manager->get()),
+ manager_(manager),
+ owns_resource_(owns_resource),
+ resource_handle_(std::move(resource_handle)),
+ resource_mgr_(ctx->resource_manager()) {}
~MemoryDatasetV2() override {
- cache_->Unref();
- cleanup_();
+ manager_->Unref();
+ if (owns_resource_) {
+ Status s = resource_mgr_->Delete<MemoryCacheManager>(
+ resource_handle_.container(), resource_handle_.name());
+ if (!s.ok()) {
+ LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
+ }
+ }
}
protected:
@@ -1058,8 +1053,10 @@
}
private:
- std::function<void()> cleanup_;
+ MemoryCacheManager* const manager_; // Owned.
+ const bool owns_resource_;
const ResourceHandle resource_handle_;
+ ResourceMgr* const resource_mgr_; // Not owned.
};
CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
@@ -1077,33 +1074,39 @@
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
resource_id_counter.fetch_add(1));
if (op_version_ == 2) {
- MemoryCache* cache = nullptr;
+ bool owns_resource = false;
+ MemoryCacheManager* manager = nullptr;
auto handle = HandleFromInput(ctx, 2);
- Status s = ctx->resource_manager()->Lookup<MemoryCache>(
- handle.container(), handle.name(), &cache);
+ Status s = ctx->resource_manager()->Lookup<MemoryCacheManager>(
+ handle.container(), handle.name(), &manager);
if (errors::IsNotFound(s)) {
- OP_REQUIRES_OK(ctx,
- ctx->resource_manager()->LookupOrCreate<MemoryCache>(
- container, name, &cache, [](MemoryCache** cache) {
- *cache = new MemoryCache();
- return Status::OK();
- }));
- handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
+ owns_resource = true;
+ OP_REQUIRES_OK(
+ ctx,
+ ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
+ container, name, &manager, [](MemoryCacheManager** manager) {
+ *manager = new MemoryCacheManager();
+ return Status::OK();
+ }));
+ handle = MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
} else {
OP_REQUIRES_OK(ctx, s);
}
- // Ownership of cache is transferred onto `MemoryDatasetV2`.
- *output = new MemoryDatasetV2(ctx, input, cache, std::move(handle));
+ // Ownership of manager is transferred onto `MemoryDatasetV2`.
+ *output = new MemoryDatasetV2(ctx, input, manager, std::move(handle),
+ owns_resource);
} else {
- MemoryCache* cache;
- OP_REQUIRES_OK(ctx, ctx->resource_manager()->LookupOrCreate<MemoryCache>(
- container, name, &cache, [](MemoryCache** cache) {
- *cache = new MemoryCache();
- return Status::OK();
- }));
- auto handle = MakeResourceHandle<MemoryCache>(ctx, container, name);
- // Ownership of cache is transferred onto `MemoryDataset`.
- *output = new MemoryDataset(ctx, input, cache, handle);
+ MemoryCacheManager* manager;
+ OP_REQUIRES_OK(
+ ctx, ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
+ container, name, &manager, [](MemoryCacheManager** manager) {
+ *manager = new MemoryCacheManager();
+ return Status::OK();
+ }));
+ auto handle =
+ MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
+ // Ownership of manager is transferred onto `MemoryDataset`.
+ *output = new MemoryDataset(ctx, input, manager, std::move(handle));
}
} else {
if (op_version_ == 2) {
diff --git a/tensorflow/core/kernels/data/cache_ops.cc b/tensorflow/core/kernels/data/cache_ops.cc
index 8b58e7b..90c2e90 100644
--- a/tensorflow/core/kernels/data/cache_ops.cc
+++ b/tensorflow/core/kernels/data/cache_ops.cc
@@ -31,7 +31,7 @@
} // namespace
-string MemoryCache::DebugString() const { return kMemoryCache; }
+string MemoryCacheManager::DebugString() const { return kMemoryCache; }
void MemoryCache::Complete(std::vector<std::vector<Tensor>>&& cache) {
mutex_lock l(mu_);
@@ -65,19 +65,15 @@
AnonymousMemoryCacheHandleOp::AnonymousMemoryCacheHandleOp(
OpKernelConstruction* ctx)
- : AnonymousResourceOp<MemoryCache>(ctx) {}
-
-void AnonymousMemoryCacheHandleOp::Compute(OpKernelContext* ctx) {
- AnonymousResourceOp<MemoryCache>::Compute(ctx);
-}
+ : AnonymousResourceOp<MemoryCacheManager>(ctx) {}
string AnonymousMemoryCacheHandleOp::name() { return kMemoryCache; }
Status AnonymousMemoryCacheHandleOp::CreateResource(
OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
- FunctionLibraryRuntime* lib, MemoryCache** resource) {
- *resource = new MemoryCache();
+ FunctionLibraryRuntime* lib, MemoryCacheManager** manager) {
+ *manager = new MemoryCacheManager();
return Status::OK();
}
diff --git a/tensorflow/core/kernels/data/cache_ops.h b/tensorflow/core/kernels/data/cache_ops.h
index d21679b..c670d6f 100644
--- a/tensorflow/core/kernels/data/cache_ops.h
+++ b/tensorflow/core/kernels/data/cache_ops.h
@@ -27,12 +27,10 @@
// The expected use is that a single `MemoryWriterIterator` populates the
// cache with dataset elements. Once all elements are cached, the cache can
// be used by one or more `MemoryReaderIterator`s.
-class MemoryCache : public ResourceBase {
+class MemoryCache {
public:
MemoryCache() = default;
- string DebugString() const override;
-
// Marks the cache as completed.
void Complete(std::vector<std::vector<Tensor>>&& cache);
@@ -55,11 +53,24 @@
std::vector<std::vector<Tensor>> cache_ TF_GUARDED_BY(mu_);
};
+// A resource wrapping a shared instance of a memory cache.
+class MemoryCacheManager : public ResourceBase {
+ public:
+ MemoryCacheManager() : cache_(std::make_shared<MemoryCache>()) {}
+
+ string DebugString() const override;
+
+ std::shared_ptr<MemoryCache> get() { return cache_; }
+
+ private:
+ std::shared_ptr<MemoryCache> cache_;
+};
+
// Creates an instance of cache resource and transfers ownership to the caller.
-class AnonymousMemoryCacheHandleOp : public AnonymousResourceOp<MemoryCache> {
+class AnonymousMemoryCacheHandleOp
+ : public AnonymousResourceOp<MemoryCacheManager> {
public:
explicit AnonymousMemoryCacheHandleOp(OpKernelConstruction* ctx);
- void Compute(OpKernelContext* ctx) override;
private:
string name() override;
@@ -67,7 +78,7 @@
std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib,
- MemoryCache** resource) override;
+ MemoryCacheManager** manager) override;
};
// Deletes an instance of cache resource.