Add a Weak-referable subclass to RefCounted.

Also move ResourceBase to WeakRefCounted, such that ResourceMgr can be aware of destroyed resources via WeakPtr.

A follow-up CL will migrate other AnonymousResourceOp subclasses to ref-counting Resource handles. Serializing them (and publishing them to ResourceMgr by name) is now fully supported via the new WeakPtr holders in ResourceMgr.

- WeakRefCounted is the new subclass of RefCounted. Instances of a WeakRefCounted subclass T can be weakly referenced by WeakPtr<T>.
- WeakPtr::GetNewRef() returns a new reference (as RefCountedPtr) to the object if the object is still valid, or nullptr if it is no longer valid.

The ref-counting mechanism of RefCounted is slightly modified to fix a race condition when ref_ == 1. In the presence of weak references, ref_ == 1 no longer indicates that the calling thread is the sole owner of the object: a thread holding a weak reference could have increased the reference count after this check.

PiperOrigin-RevId: 399276448
Change-Id: I46f552a5600514e539e1ec003f29e4ea8d5f61f6
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index ece08f8..5ff2797 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -739,6 +739,7 @@
         "//tensorflow/core/platform:tensor_coding",
         "//tensorflow/core/platform:types",
         "//tensorflow/core/util:managed_stack_trace",
+        "@com_google_absl//absl/strings:str_format",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/core/framework/resource_base.h b/tensorflow/core/framework/resource_base.h
index 43ac6a9e..518f6c2 100644
--- a/tensorflow/core/framework/resource_base.h
+++ b/tensorflow/core/framework/resource_base.h
@@ -30,7 +30,7 @@
 // This is the base class of all resource classes. Each resource must be
 // represented as a sub-class of ResourceBase (which is reference counted) to be
 // able to work with resource facilities such ResourceHandle and ResourceMgr.
-class ResourceBase : public core::RefCounted {
+class ResourceBase : public core::WeakRefCounted {
  public:
   // Returns a debug string for *this.
   virtual std::string DebugString() const = 0;
@@ -45,7 +45,6 @@
                                  DebugString());
   }
 };
-
 }  //  end namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_BASE_H_
diff --git a/tensorflow/core/framework/resource_handle.cc b/tensorflow/core/framework/resource_handle.cc
index 0ca9fc8..aaab872 100644
--- a/tensorflow/core/framework/resource_handle.cc
+++ b/tensorflow/core/framework/resource_handle.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/framework/resource_handle.h"
 
+#include "absl/strings/str_format.h"
 #include "tensorflow/core/framework/resource_handle.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -94,10 +95,12 @@
   ResourceHandle result;
   result.resource_.reset(resource, /*add_ref=*/false);
   result.set_device(device_name);
-  // "container" is only a ResourceMgr-only concept
-  result.set_container("");
+  // All resources owned by anonymous handles are put into the same container,
+  // and they get process-unique handle names.
+  result.set_container("Anonymous");
   result.set_definition_stack_trace(definition_stack_trace);
-  result.set_name(strings::StrCat("_AnonymousResource", GenerateUniqueId()));
+  result.set_name(
+      absl::StrFormat("Resource-%d-at-%p", GenerateUniqueId(), resource));
   result.set_hash_code(type_index.hash_code());
   result.set_maybe_type_name(type_index.name());
   result.set_dtypes_and_shapes(dtypes_and_shapes);
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 3c96618..1387aad 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -95,31 +95,39 @@
 }
 
 ResourceMgr::ResourceAndName::ResourceAndName()
-    : resource(nullptr), name(nullptr), resource_owner(nullptr) {}
+    : resource(core::RefCountPtr<ResourceBase>{nullptr}), name(nullptr) {}
 
-ResourceMgr::ResourceAndName::ResourceAndName(ResourceBase* resource,
-                                              string name,
-                                              ResourceBase* resource_owner)
-    : resource(resource),
-      name(absl::make_unique<string>(std::move(name))),
-      resource_owner(resource_owner) {}
+ResourceMgr::ResourceAndName::ResourceAndName(
+    StrongOrWeakResourcePtr&& resource, string name)
+    : resource(std::move(resource)),
+      name(absl::make_unique<string>(std::move(name))) {}
+
+core::RefCountPtr<ResourceBase> ResourceMgr::ResourceAndName::GetResource()
+    const {
+  if (absl::holds_alternative<core::RefCountPtr<ResourceBase>>(resource)) {
+    ResourceBase* ptr =
+        absl::get<core::RefCountPtr<ResourceBase>>(resource).get();
+    // A default-constructed entry holds a null strong pointer; skip Ref().
+    if (ptr != nullptr) ptr->Ref();
+    return core::RefCountPtr<ResourceBase>{ptr};
+  } else if (absl::holds_alternative<core::WeakPtr<ResourceBase>>(resource)) {
+    return absl::get<core::WeakPtr<ResourceBase>>(resource).GetNewRef();
+  }
+  return nullptr;
+}
 
 ResourceMgr::ResourceAndName::ResourceAndName(
     ResourceAndName&& other) noexcept {
-  resource = nullptr;
-  std::swap(resource, other.resource);
   name = std::move(other.name);
-  resource_owner = std::move(other.resource_owner);
+  resource = std::move(other.resource);
 }
 
 ResourceMgr::ResourceAndName::~ResourceAndName() {}
 
 ResourceMgr::ResourceAndName& ResourceMgr::ResourceAndName::operator=(
     ResourceAndName&& other) noexcept {
-  resource = nullptr;
-  std::swap(resource, other.resource);
   name = std::move(other.name);
-  resource_owner = std::move(other.resource_owner);
+  resource = std::move(other.resource);
   return *this;
 }
 
@@ -158,8 +166,9 @@
     for (const auto& q : *p.second) {
       const Key& key = q.first;
       const char* type = DebugTypeName(key.first);
+      const core::RefCountPtr<ResourceBase> resource = q.second.GetResource();
       Line l{&container, port::Demangle(type), q.second.name.get(),
-             q.second.resource->DebugString()};
+             resource ? resource->DebugString() : "<nullptr>"};
       lines.push_back(l);
     }
   }
@@ -184,8 +193,11 @@
 
   // NOTE: Separating out the construction of the map key and value so that the
   // key can contain a StringPiece that borrows from the string in the value.
-  ResourceAndName resource_and_name(resource, name,
-                                    owns_resource ? resource : nullptr);
+  ResourceAndName resource_and_name(
+      owns_resource
+          ? StrongOrWeakResourcePtr{core::RefCountPtr<ResourceBase>{resource}}
+          : StrongOrWeakResourcePtr{core::WeakPtr<ResourceBase>{resource}},
+      name);
   StringPiece borrowed_name(*resource_and_name.name);
   Container::value_type key_and_value(Key(type.hash_code(), borrowed_name),
                                       std::move(resource_and_name));
@@ -226,8 +238,32 @@
     return errors::NotFound("Resource ", container, "/", resource_name, "/",
                             type_name, " does not exist.");
   }
-  *resource = iter->second.resource;
-  (*resource)->Ref();
+  ResourceBase* ptr = iter->second.GetResource().release();
+  if (ptr == nullptr) {
+    return errors::NotFound("Resource ", container, "/", resource_name, "/",
+                            type_name, " has been destroyed.");
+  }
+  *resource = ptr;
+  return Status::OK();
+}
+
+Status ResourceMgr::PopResourceAndName(const string& container,
+                                       uint64 type_hash_code,
+                                       const string& resource_name,
+                                       const string& type_name,
+                                       ResourceAndName& resource_and_name) {
+  mutex_lock l(mu_);
+  Container* b = gtl::FindPtrOrNull(containers_, container);
+  if (b == nullptr) {
+    return errors::NotFound("Container ", container, " does not exist.");
+  }
+  auto iter = b->find({type_hash_code, resource_name});
+  if (iter == b->end()) {
+    return errors::NotFound("Resource ", container, "/", resource_name, "/",
+                            type_name, " does not exist.");
+  }
+  std::swap(resource_and_name, iter->second);
+  b->erase(iter);
   return Status::OK();
 }
 
@@ -235,21 +271,8 @@
                              const string& resource_name,
                              const string& type_name) {
   ResourceAndName resource_and_name;
-  {
-    mutex_lock l(mu_);
-    Container* b = gtl::FindPtrOrNull(containers_, container);
-    if (b == nullptr) {
-      return errors::NotFound("Container ", container, " does not exist.");
-    }
-    auto iter = b->find({type_hash_code, resource_name});
-    if (iter == b->end()) {
-      return errors::NotFound("Resource ", container, "/", resource_name, "/",
-                              type_name, " does not exist.");
-    }
-    std::swap(resource_and_name, iter->second);
-    b->erase(iter);
-  }
-  DCHECK(resource_and_name.resource != nullptr);
+  TF_RETURN_IF_ERROR(PopResourceAndName(
+      container, type_hash_code, resource_name, type_name, resource_and_name));
   return Status::OK();
 }
 
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index cb1a80a..d2be5eb 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -22,6 +22,7 @@
 #include <typeinfo>
 #include <unordered_map>
 
+#include "absl/types/variant.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -166,8 +167,9 @@
   // transfer the ownership of any ref on "resource" to *this, regardless of
   // whether this operation succeeds or fails.
   //
-  // The caller must ensure calling this->Delete() on the name before the
-  // resource is destroyed.
+  // After the resource is destroyed, lookups from the manager fail.
+  // The caller must call this->Delete() on the name to free up the memory
+  // entry of the name.
   //
   // REQUIRES: std::is_base_of<ResourceBase, T>
   // REQUIRES: resource != nullptr.
@@ -246,19 +248,25 @@
       return (x.second == y.second) && (x.first == y.first);
     }
   };
+  typedef absl::variant<core::RefCountPtr<ResourceBase>,
+                        core::WeakPtr<ResourceBase>>
+      StrongOrWeakResourcePtr;
+
   struct ResourceAndName {
-    ResourceBase* resource;
+    StrongOrWeakResourcePtr resource;
     std::unique_ptr<string> name;
-    core::RefCountPtr<ResourceBase> resource_owner;
 
     ResourceAndName();
-    ResourceAndName(ResourceBase* resource, std::string name,
-                    ResourceBase* resource_owner);
+    ResourceAndName(StrongOrWeakResourcePtr&& resource, std::string name);
     ResourceAndName(ResourceAndName&& other) noexcept;
     ~ResourceAndName();
 
     ResourceAndName& operator=(ResourceAndName&&) noexcept;
 
+    // Returns a strong reference to resource, or nullptr if the resource is
+    // no longer valid.
+    core::RefCountPtr<ResourceBase> GetResource() const;
+
    private:
     TF_DISALLOW_COPY_AND_ASSIGN(ResourceAndName);
   };
@@ -296,6 +304,12 @@
   Status DoDelete(const std::string& container, TypeIndex type,
                   const std::string& resource_name) TF_MUST_USE_RESULT;
 
+  // Pops the ResourceAndName entry. The entry is moved from the list to
+  // the output argument `resource_and_name`.
+  Status PopResourceAndName(
+      const std::string& container, uint64 type_hash_code,
+      const std::string& resource_name, const std::string& type_name,
+      ResourceAndName& resource_and_name) TF_MUST_USE_RESULT;
   // Inserts the type name for 'hash_code' into the hash_code to type name map.
   Status InsertDebugTypeName(uint64 hash_code, const std::string& type_name)
       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc
index a1ff23b..b6f3096 100644
--- a/tensorflow/core/framework/resource_mgr_test.cc
+++ b/tensorflow/core/framework/resource_mgr_test.cc
@@ -180,6 +180,24 @@
   EXPECT_TRUE(kitty->RefCountIsOne());
   EXPECT_EQ("R/kitty", Find<Resource>(rm, "foo", "bar"));
 
+  {
+    core::RefCountPtr<Resource> dog{new Resource("dog")};
+    TF_CHECK_OK(rm.CreateUnowned("foo", "bark", dog.get()));
+    EXPECT_EQ("R/dog", Find<Resource>(rm, "foo", "bark"));
+    EXPECT_EQ(1, dog->WeakRefCount());
+    {
+      ResourceMgr rm1;
+      TF_CHECK_OK(rm1.CreateUnowned("foo", "bark", dog.get()));
+      EXPECT_EQ("R/dog", Find<Resource>(rm1, "foo", "bark"));
+      EXPECT_EQ(2, dog->WeakRefCount());
+    }
+    // If manager goes out of scope, the resource loses the weak ref.
+    EXPECT_EQ(1, dog->WeakRefCount());
+  }
+  // If resource goes out of scope, the look up reports not found.
+  HasError(FindErr<Resource>(rm, "foo", "bark"), error::NOT_FOUND,
+           "Resource foo/bark");
+
   // Drop the whole container foo.
   TF_CHECK_OK(rm.Cleanup("foo"));
   HasError(FindErr<Resource>(rm, "foo", "bar"), error::NOT_FOUND,
diff --git a/tensorflow/core/lib/core/BUILD b/tensorflow/core/lib/core/BUILD
index d4f26ca..8d7c634 100644
--- a/tensorflow/core/lib/core/BUILD
+++ b/tensorflow/core/lib/core/BUILD
@@ -207,7 +207,6 @@
         "blocking_counter_test.cc",
         "coding_test.cc",
         "notification_test.cc",
-        "refcount_test.cc",
         "status_test.cc",
         "threadpool_test.cc",
     ],
diff --git a/tensorflow/core/lib/core/refcount_test.cc b/tensorflow/core/lib/core/refcount_test.cc
deleted file mode 100644
index e4957e2..0000000
--- a/tensorflow/core/lib/core/refcount_test.cc
+++ /dev/null
@@ -1,107 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/core/refcount.h"
-
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace core {
-namespace {
-
-static int constructed = 0;
-static int destroyed = 0;
-
-class MyRef : public RefCounted {
- public:
-  MyRef() { constructed++; }
-  ~MyRef() override { destroyed++; }
-};
-
-class RefTest : public ::testing::Test {
- public:
-  RefTest() {
-    constructed = 0;
-    destroyed = 0;
-  }
-};
-
-TEST_F(RefTest, New) {
-  MyRef* ref = new MyRef;
-  ASSERT_EQ(1, constructed);
-  ASSERT_EQ(0, destroyed);
-  ref->Unref();
-  ASSERT_EQ(1, constructed);
-  ASSERT_EQ(1, destroyed);
-}
-
-TEST_F(RefTest, RefUnref) {
-  MyRef* ref = new MyRef;
-  ASSERT_EQ(1, constructed);
-  ASSERT_EQ(0, destroyed);
-  ref->Ref();
-  ASSERT_EQ(0, destroyed);
-  ref->Unref();
-  ASSERT_EQ(0, destroyed);
-  ref->Unref();
-  ASSERT_EQ(1, destroyed);
-}
-
-TEST_F(RefTest, RefCountOne) {
-  MyRef* ref = new MyRef;
-  ASSERT_TRUE(ref->RefCountIsOne());
-  ref->Unref();
-}
-
-TEST_F(RefTest, RefCountNotOne) {
-  MyRef* ref = new MyRef;
-  ref->Ref();
-  ASSERT_FALSE(ref->RefCountIsOne());
-  ref->Unref();
-  ref->Unref();
-}
-
-TEST_F(RefTest, ConstRefUnref) {
-  const MyRef* cref = new MyRef;
-  ASSERT_EQ(1, constructed);
-  ASSERT_EQ(0, destroyed);
-  cref->Ref();
-  ASSERT_EQ(0, destroyed);
-  cref->Unref();
-  ASSERT_EQ(0, destroyed);
-  cref->Unref();
-  ASSERT_EQ(1, destroyed);
-}
-
-TEST_F(RefTest, ReturnOfUnref) {
-  MyRef* ref = new MyRef;
-  ref->Ref();
-  EXPECT_FALSE(ref->Unref());
-  EXPECT_TRUE(ref->Unref());
-}
-
-TEST_F(RefTest, ScopedUnref) {
-  { ScopedUnref unref(new MyRef); }
-  EXPECT_EQ(destroyed, 1);
-}
-
-TEST_F(RefTest, ScopedUnref_Nullptr) {
-  { ScopedUnref unref(nullptr); }
-  EXPECT_EQ(destroyed, 0);
-}
-
-}  // namespace
-}  // namespace core
-}  // namespace tensorflow
diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD
index 1d2ec6a..b7d0970 100644
--- a/tensorflow/core/platform/BUILD
+++ b/tensorflow/core/platform/BUILD
@@ -585,6 +585,22 @@
     hdrs = ["refcount.h"],
     deps = [
         ":logging",
+        ":mutex",
+        ":thread_annotations",
+    ],
+)
+
+tf_cc_test(
+    name = "refcount_test",
+    size = "small",
+    srcs = [
+        "refcount_test.cc",
+    ],
+    deps = [
+        ":env",
+        ":refcount",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
     ],
 )
 
diff --git a/tensorflow/core/platform/refcount.h b/tensorflow/core/platform/refcount.h
index 621f786..9f2cb2d 100644
--- a/tensorflow/core/platform/refcount.h
+++ b/tensorflow/core/platform/refcount.h
@@ -17,9 +17,12 @@
 #define TENSORFLOW_CORE_PLATFORM_REFCOUNT_H_
 
 #include <atomic>
+#include <map>
 #include <memory>
 
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
 
 namespace tensorflow {
 namespace core {
@@ -56,6 +59,12 @@
   // be instantiated directly. Only subclasses can be instantiated.
   virtual ~RefCounted();
 
+  // Increments reference count by one if the object is not being destructed.
+  // This function is used by WeakRefCounted for securely acquiring a
+  // strong reference. It is only safe to call this as part of the weak
+  // reference implementation.
+  bool TryRef() const;
+
  private:
   mutable std::atomic_int_fast32_t ref_;
 
@@ -65,7 +74,7 @@
 
 // A deleter class to form a std::unique_ptr that unrefs objects.
 struct RefCountDeleter {
-  void operator()(tensorflow::core::RefCounted* o) const { o->Unref(); }
+  void operator()(RefCounted* o) const { o->Unref(); }
 };
 
 // A unique_ptr that unrefs the owned object on destruction.
@@ -87,28 +96,129 @@
   void operator=(const ScopedUnref&) = delete;
 };
 
+// Forward declaration for friend class of WeakRefCounted.
+template <typename T>
+class WeakPtr;
+
+// A base class for RefCounted objects that allow weak references by WeakPtr.
+// WeakRefCounted and every WeakPtr to it, each holds a strong reference to a
+// WeakRefData.
+//
+// If the WeakRefCounted is valid, WeakPtr::GetNewRef() returns a new strong
+// reference to the WeakRefCounted.
+// If the WeakRefCounted is being destructed, `WeakRefCounted::ref_ == 0`;
+// if the WeakRefCounted is already destructed, `WeakRefData::ptr == nullptr`.
+// In either case, WeakPtr::GetNewRef() returns a nullptr.
+class WeakRefCounted : public RefCounted {
+ public:
+  int WeakRefCount() const {
+    // Each weak ref owns one ref to data_, and *this owns the last one.
+    return data_->RefCount() - 1;
+  }
+
+ protected:
+  ~WeakRefCounted() override { data_->Reset(); }
+
+ private:
+  struct WeakRefData : public RefCounted {
+    explicit WeakRefData(WeakRefCounted* ptr) : ptr(ptr) {}
+
+    mutable mutex mu;
+    WeakRefCounted* ptr TF_GUARDED_BY(mu);
+
+    void Reset() {
+      mutex_lock ml(mu);
+      ptr = nullptr;
+    }
+
+    WeakRefCounted* GetNewRef() {
+      mutex_lock ml(mu);
+      if (ptr != nullptr && ptr->TryRef()) {
+        return ptr;
+      }
+      return nullptr;
+    }
+  };
+
+  RefCountPtr<WeakRefData> data_{new WeakRefData(this)};
+
+  template <typename T>
+  friend class WeakPtr;
+};
+
+// A weak reference to a WeakRefCounted object. See WeakRefCounted.
+template <typename T>
+class WeakPtr {
+ public:
+  // Creates an empty weak reference; GetNewRef() on it returns nullptr.
+  WeakPtr() : data_(nullptr) {}
+  // Creates a weak reference to ptr, which must be valid during construction.
+  explicit WeakPtr(WeakRefCounted* ptr) : data_(nullptr) {
+    if (ptr != nullptr) {
+      ptr->data_->Ref();
+      data_.reset(ptr->data_.get());
+    }
+  }
+
+  // Returns a new strong reference to the referred object, or nullptr if the
+  // object is in an invalid state (being destructed or already destructed).
+  RefCountPtr<T> GetNewRef() const {
+    RefCountPtr<T> ref;
+    if (data_ != nullptr) {
+      WeakRefCounted* ptr = data_->GetNewRef();
+      ref.reset(static_cast<T*>(ptr));
+    }
+    return ref;
+  }
+
+ private:
+  // NOTE(feyu): change this to an IntrusivePtr to make WeakPtr copyable.
+  RefCountPtr<WeakRefCounted::WeakRefData> data_;
+};
+
 // Inlined routines, since these are performance critical
 inline RefCounted::RefCounted() : ref_(1) {}
 
-inline RefCounted::~RefCounted() { DCHECK_EQ(ref_.load(), 0); }
+inline RefCounted::~RefCounted() {
+  // A destructing object has ref_ == 0.
+  // It is a bug if the object is resurrected (ref_ > 0) before delete is
+  // called by Unref().
+  DCHECK_EQ(ref_.load(), 0);
+}
 
 inline void RefCounted::Ref() const {
-  DCHECK_GE(ref_.load(), 1);
-  ref_.fetch_add(1, std::memory_order_relaxed);
+  // Ref() uses relaxed order because it is never called with old_ref == 0.
+  // When old_ref >= 1, no actions depend on the new value of ref.
+  int_fast32_t old_ref = ref_.fetch_add(1, std::memory_order_relaxed);
+  DCHECK_GT(old_ref, 0);
+}
+
+inline bool RefCounted::TryRef() const {
+  // This is not on a hot path.
+  // Be conservative and use seq_cst to prevent racing with Unref() when
+  // old_ref == 0, as done in LLVM libc++.
+  int_fast32_t old_ref = ref_.load();
+  while (old_ref != 0) {
+    if (ref_.compare_exchange_weak(old_ref, old_ref + 1)) {
+      return true;
+    }
+  }
+  // Already destructing, cannot increase ref.
+  return false;
 }
 
 inline bool RefCounted::Unref() const {
   DCHECK_GT(ref_.load(), 0);
-  // If ref_==1, this object is owned only by the caller. Bypass a locked op
-  // in that case.
-  if (RefCountIsOne() || ref_.fetch_sub(1) == 1) {
-    // Make DCHECK in ~RefCounted happy
-    DCHECK((ref_.store(0), true));
+  // acq_rel is used to prevent reordering that would introduce object
+  // accesses after destruction.
+
+  // Using release alone is a bug on systems where acq_rel differs from release
+  // (e.g. arm), according to Herb Sutter's 2012 talk on "Atomic<> Weapons".
+  if (ref_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
     delete this;
     return true;
-  } else {
-    return false;
   }
+  return false;
 }
 
 inline int_fast32_t RefCounted::RefCount() const {
diff --git a/tensorflow/core/platform/refcount_test.cc b/tensorflow/core/platform/refcount_test.cc
new file mode 100644
index 0000000..fbd0605
--- /dev/null
+++ b/tensorflow/core/platform/refcount_test.cc
@@ -0,0 +1,165 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/refcount.h"
+
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/threadpool.h"
+
+namespace tensorflow {
+namespace core {
+namespace {
+
+class RefTest : public ::testing::Test {
+ public:
+  RefTest() {
+    constructed_ = 0;
+    destroyed_ = 0;
+  }
+
+  static int constructed_;
+  static int destroyed_;
+};
+
+int RefTest::constructed_;
+int RefTest::destroyed_;
+
+class MyRef : public RefCounted {
+ public:
+  MyRef() { RefTest::constructed_++; }
+  ~MyRef() override { RefTest::destroyed_++; }
+};
+
+TEST_F(RefTest, New) {
+  MyRef* ref = new MyRef;
+  ASSERT_EQ(1, constructed_);
+  ASSERT_EQ(0, destroyed_);
+  ref->Unref();
+  ASSERT_EQ(1, constructed_);
+  ASSERT_EQ(1, destroyed_);
+}
+
+TEST_F(RefTest, RefUnref) {
+  MyRef* ref = new MyRef;
+  ASSERT_EQ(1, constructed_);
+  ASSERT_EQ(0, destroyed_);
+  ref->Ref();
+  ASSERT_EQ(0, destroyed_);
+  ref->Unref();
+  ASSERT_EQ(0, destroyed_);
+  ref->Unref();
+  ASSERT_EQ(1, destroyed_);
+}
+
+TEST_F(RefTest, RefCountOne) {
+  MyRef* ref = new MyRef;
+  ASSERT_TRUE(ref->RefCountIsOne());
+  ref->Unref();
+}
+
+TEST_F(RefTest, RefCountNotOne) {
+  MyRef* ref = new MyRef;
+  ref->Ref();
+  ASSERT_FALSE(ref->RefCountIsOne());
+  ref->Unref();
+  ref->Unref();
+}
+
+TEST_F(RefTest, ConstRefUnref) {
+  const MyRef* cref = new MyRef;
+  ASSERT_EQ(1, constructed_);
+  ASSERT_EQ(0, destroyed_);
+  cref->Ref();
+  ASSERT_EQ(0, destroyed_);
+  cref->Unref();
+  ASSERT_EQ(0, destroyed_);
+  cref->Unref();
+  ASSERT_EQ(1, destroyed_);
+}
+
+TEST_F(RefTest, ReturnOfUnref) {
+  MyRef* ref = new MyRef;
+  ref->Ref();
+  EXPECT_FALSE(ref->Unref());
+  EXPECT_TRUE(ref->Unref());
+}
+
+TEST_F(RefTest, ScopedUnref) {
+  { ScopedUnref unref(new MyRef); }
+  EXPECT_EQ(destroyed_, 1);
+}
+
+TEST_F(RefTest, ScopedUnref_Nullptr) {
+  { ScopedUnref unref(nullptr); }
+  EXPECT_EQ(destroyed_, 0);
+}
+
+class ObjType : public WeakRefCounted {};
+
+TEST(WeakPtr, SingleThread) {
+  auto obj = new ObjType();
+  auto weakptr = WeakPtr<ObjType>(obj);
+
+  ASSERT_TRUE(obj->RefCountIsOne());
+  EXPECT_EQ(obj->WeakRefCount(), 1);
+  EXPECT_NE(weakptr.GetNewRef(), nullptr);
+
+  obj->Unref();
+  EXPECT_EQ(weakptr.GetNewRef(), nullptr);
+}
+
+TEST(WeakPtr, MultiThreadedWeakRef) {
+  // Exercise 100 times to make sure both branches of fn are hit.
+  std::atomic<int> hit_destructed{0};
+
+  auto env = Env::Default();
+
+  for (int i = 0; i < 100; i++) {
+    auto obj = new ObjType();
+    auto weakptr = WeakPtr<ObjType>(obj);
+
+    bool obj_destructed = false;
+    EXPECT_EQ(obj->WeakRefCount(), 1);
+
+    auto fn = [&]() {
+      auto ref = weakptr.GetNewRef();
+      if (ref != nullptr) {
+        EXPECT_EQ(ref.get(), obj);
+        EXPECT_EQ(ref->WeakRefCount(), 1);
+        EXPECT_GE(ref->RefCount(), 1);
+      } else {
+        hit_destructed++;
+        EXPECT_TRUE(obj_destructed);
+      }
+    };
+
+    auto t1 = env->StartThread(ThreadOptions{}, "thread-1", fn);
+    auto t2 = env->StartThread(ThreadOptions{}, "thread-2", fn);
+
+    env->SleepForMicroseconds(10);
+    obj_destructed = true;  // This shall run before weakref is purged.
+    obj->Unref();
+
+    delete t1;
+    delete t2;
+
+    EXPECT_EQ(weakptr.GetNewRef(), nullptr);
+  }
+  ASSERT_GT(hit_destructed, 0);
+  ASSERT_LT(hit_destructed, 200);  // 2 threads per iteration.
+}
+}  // namespace
+}  // namespace core
+}  // namespace tensorflow