/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PLATFORM_REFCOUNT_H_
#define TENSORFLOW_CORE_PLATFORM_REFCOUNT_H_
#include <atomic>
#include <map>
#include <memory>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace core {
class RefCounted {
public:
// Initial reference count is one.
RefCounted();
// Increments reference count by one.
void Ref() const;
// Decrements reference count by one. If the count remains
// positive, returns false. When the count reaches zero, returns
// true and deletes this, in which case the caller must not access
// the object afterward.
bool Unref() const;
// Gets the current reference count.
int_fast32_t RefCount() const;
// Returns whether the reference count is one.
// If the reference count is used in the conventional way, a
// reference count of 1 implies that the current thread owns the
// reference and no other thread shares it.
// This call performs the test for a reference count of one and, at the
// same time, the memory barrier needed for the owning thread to act on
// the object, knowing that it has exclusive access.
bool RefCountIsOne() const;
protected:
// Make destructor protected so that RefCounted objects cannot
// be instantiated directly. Only subclasses can be instantiated.
virtual ~RefCounted();
// Increments the reference count by one if the object is not being
// destructed. This function is used by WeakRefCounted to safely acquire a
// strong reference. It is only safe to call this as part of the weak
// reference implementation.
bool TryRef() const;
private:
mutable std::atomic_int_fast32_t ref_;
RefCounted(const RefCounted&) = delete;
void operator=(const RefCounted&) = delete;
};
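// Illustrative usage sketch (not part of this header): manual use of the
// RefCounted API with a hypothetical subclass `MyThing`. The initial count
// is one, each Ref() adds one, and the Unref() that reaches zero deletes
// the object. RefCountIsOne() can gate in-place mutation when the current
// thread holds the only reference.
//
//   class MyThing : public RefCounted {
//    public:
//     void MutateInPlace() { DCHECK(RefCountIsOne()); /* exclusive */ }
//   };
//
//   MyThing* t = new MyThing;  // ref count == 1
//   t->Ref();                  // ref count == 2
//   t->Unref();                // ref count == 1, returns false
//   t->MutateInPlace();        // safe: this thread holds the only reference
//   t->Unref();                // ref count == 0, deletes t, returns true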
// A deleter class to form a std::unique_ptr that unrefs objects.
struct RefCountDeleter {
void operator()(RefCounted* o) const { o->Unref(); }
};
// A unique_ptr that unrefs the owned object on destruction.
template <typename T>
using RefCountPtr = std::unique_ptr<T, RefCountDeleter>;
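// Illustrative sketch (using the hypothetical `MyThing` above): a
// RefCountPtr adopts one reference and calls Unref() when it goes out of
// scope, so no manual Unref() is needed.
//
//   {
//     RefCountPtr<MyThing> p(new MyThing);  // owns the initial reference
//     p->MutateInPlace();
//   }  // p's destructor calls Unref(), which deletes the object here.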
// Helper class to unref an object when out-of-scope.
class ScopedUnref {
public:
explicit ScopedUnref(const RefCounted* o) : obj_(o) {}
~ScopedUnref() {
if (obj_) obj_->Unref();
}
private:
const RefCounted* obj_;
ScopedUnref(const ScopedUnref&) = delete;
void operator=(const ScopedUnref&) = delete;
};
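// Illustrative sketch (hypothetical `Consume` function): ScopedUnref pairs a
// reference transferred by the caller with an Unref() at the end of the
// current scope.
//
//   void Consume(MyThing* t) {  // caller passes ownership of one reference
//     ScopedUnref unref(t);
//     t->MutateInPlace();
//   }  // unref's destructor calls t->Unref() here.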
// Forward declaration for friend class of WeakRefCounted.
template <typename T>
class WeakPtr;
// A base class for RefCounted objects to which weak references can be taken
// via WeakPtr. The WeakRefCounted object and every WeakPtr to it each hold a
// strong reference to a shared WeakRefData.
//
// If the WeakRefCounted is valid, WeakPtr::GetNewRef() returns a new strong
// reference to the WeakRefCounted.
// If the WeakRefCounted is being destructed, `WeakRefCounted::ref_ == 0`;
// if the WeakRefCounted is already destructed, `WeakRefData::ptr == nullptr`.
// In either case, WeakPtr::GetNewRef() returns nullptr.
class WeakRefCounted : public RefCounted {
public:
int WeakRefCount() const {
// Each weak ref owns one ref to data_, and *this owns the last one.
return data_->RefCount() - 1;
}
protected:
~WeakRefCounted() override { data_->Reset(); }
private:
struct WeakRefData : public RefCounted {
explicit WeakRefData(WeakRefCounted* ptr) : ptr(ptr) {}
mutable mutex mu;
WeakRefCounted* ptr TF_GUARDED_BY(mu);
void Reset() {
mutex_lock ml(mu);
ptr = nullptr;
}
WeakRefCounted* GetNewRef() {
mutex_lock ml(mu);
if (ptr != nullptr && ptr->TryRef()) {
return ptr;
}
return nullptr;
}
};
RefCountPtr<WeakRefData> data_{new WeakRefData(this)};
template <typename T>
friend class WeakPtr;
// MSVC14 workaround: in MSVC14, the access permissions of a nested class
// are not treated like those of an ordinary member.
friend struct WeakRefData;
};
// A weak reference to a WeakRefCounted object. See WeakRefCounted.
template <typename T>
class WeakPtr {
public:
WeakPtr() : data_(nullptr) {}
// Creates a weak reference to the given WeakRefCounted object.
// ptr must remain valid while the constructor runs.
explicit WeakPtr(WeakRefCounted* ptr) : data_(nullptr) {
if (ptr != nullptr) {
ptr->data_->Ref();
data_.reset(ptr->data_.get());
}
}
// Returns a new strong reference to the referred object, or nullptr if the
// object is in an invalid state (being destructed or already destructed).
RefCountPtr<T> GetNewRef() const {
RefCountPtr<T> ref;
if (data_ != nullptr) {
WeakRefCounted* ptr = data_->GetNewRef();
ref.reset(static_cast<T*>(ptr));
}
return ref;
}
private:
// NOTE(feyu): change this to an IntrusivePtr to make WeakPtr copyable.
RefCountPtr<WeakRefCounted::WeakRefData> data_;
};
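// Illustrative sketch (hypothetical subclass `MyWeakThing`): a WeakPtr does
// not keep the object alive; GetNewRef() yields a strong reference only
// while the object still exists.
//
//   class MyWeakThing : public WeakRefCounted {};
//
//   MyWeakThing* obj = new MyWeakThing;       // strong ref count == 1
//   WeakPtr<MyWeakThing> weak(obj);           // obj->WeakRefCount() == 1
//   RefCountPtr<MyWeakThing> strong = weak.GetNewRef();  // non-null
//   strong.reset();                           // drop the extra strong ref
//   obj->Unref();                             // last strong ref, obj deleted
//   DCHECK(weak.GetNewRef() == nullptr);      // weak ref now yields nullptr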
// Inlined routines, since these are performance critical
inline RefCounted::RefCounted() : ref_(1) {}
inline RefCounted::~RefCounted() {
// A destructing object has ref_ == 0.
// It is a bug if the object is resurrected (ref_ > 0) before delete is
// called by Unref().
DCHECK_EQ(ref_.load(), 0);
}
inline void RefCounted::Ref() const {
// Ref() uses relaxed order because it is never called with old_ref == 0.
// When old_ref >= 1, no actions depend on the new value of ref.
int_fast32_t old_ref = ref_.fetch_add(1, std::memory_order_relaxed);
DCHECK_GT(old_ref, 0);
}
inline bool RefCounted::TryRef() const {
// This is not on a hot path.
// Be conservative and use seq_cst to prevent racing with Unref() when
// old_ref == 0, as done in LLVM's libc++.
int_fast32_t old_ref = ref_.load();
while (old_ref != 0) {
if (ref_.compare_exchange_weak(old_ref, old_ref + 1)) {
return true;
}
}
// Already destructing, cannot increase ref.
return false;
}
inline bool RefCounted::Unref() const {
DCHECK_GT(ref_.load(), 0);
// acq_rel is used to prevent reordering from introducing an access to the
// object after its destruction.
// Using release alone would be a bug on systems where acq_rel differs from
// release (e.g. ARM), according to Herb Sutter's 2012 talk "Atomic<> Weapons".
if (ref_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
delete this;
return true;
}
return false;
}
inline int_fast32_t RefCounted::RefCount() const {
return ref_.load(std::memory_order_acquire);
}
inline bool RefCounted::RefCountIsOne() const {
return (ref_.load(std::memory_order_acquire) == 1);
}
} // namespace core
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_REFCOUNT_H_