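Internal change: remove EagerContext::LocalRendezvousTable and its tests; the local path of CreateRendezvous() returns a new IntraProcessRendezvous.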
PiperOrigin-RevId: 400907447
Change-Id: Ia64b5ccf0a50643abcb61a01ef15e9842ba02340
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 17e1c8f..58f640d 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -76,48 +76,6 @@
} // namespace
-// Find the rendezvous instance corresponding to the step id, or create a
-// new instance if not existing.
-Rendezvous* EagerContext::LocalRendezvousTable::FindOrCreate(
- int64_t step_id, DeviceMgr* device_mgr) {
- mutex_lock l(table_lock_);
- auto iter = table_.find(step_id);
- if (iter == table_.end()) {
- iter =
- table_.insert({step_id, new IntraProcessRendezvous(device_mgr)}).first;
- // Global rendezvous: ref-count should be 1 upon creation.
- if (step_id == -1) {
- return iter->second;
- }
- }
- iter->second->Ref();
- return iter->second;
-}
-
-void EagerContext::LocalRendezvousTable::Remove(int64_t step_id) {
- mutex_lock l(table_lock_);
- auto iter = table_.find(step_id);
- if (iter != table_.end()) {
- if (iter->second) {
- iter->second->StartAbort(errors::Aborted("Cleanup ", iter->first));
- }
- table_.erase(iter);
- }
-}
-
-void EagerContext::LocalRendezvousTable::CleanUpAll() {
- mutex_lock l(table_lock_);
- for (auto iter = table_.begin(); iter != table_.end(); iter++) {
- // Unref all redezvous instance, except for global rendezvous,
- // which is cleaned up elsewhere when necessary.
- if (iter->first == -1) {
- continue;
- }
- iter->second->StartAbort(errors::Aborted("Cleanup ", iter->first));
- iter->second->Unref();
- }
-}
-
EagerContext::EagerContext(
const SessionOptions& opts,
ContextDevicePlacementPolicy default_device_placement_policy, bool async,
@@ -172,11 +130,6 @@
opts.config, local_device_mgr(),
MaybeCreateNcclCommunicator(opts.config)));
}
-
- // Initialization of local_rendezvous_table_ needs to happen before the
- // initialization of global_rendezvous_for_functions_ because the latter
- // depends on the former.
- local_rendezvous_table_ = std::make_unique<LocalRendezvousTable>();
global_rendezvous_for_functions_ =
core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
}
@@ -971,10 +924,13 @@
sg.Update(s);
}
#endif // !IS_MOBILE_PLATFORM
-
- // Reset the global rendezvous, which otherwise stores a failure state.
- ResetGlobalRendezvousForFunction();
-
+ {
+ // Reset the global function rendezvous, which otherwise stores a failure
+ // state.
+ mutex_lock l(global_rendezvous_mu_);
+ global_rendezvous_for_functions_ =
+ core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
+ }
return sg.as_summary_status();
}
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 07bb5de..1d0e4e7 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -104,10 +104,7 @@
CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr,
bool run_eager_op_as_function = false);
- void Release() override {
- local_rendezvous_table_->CleanUpAll();
- Unref();
- }
+ void Release() override { Unref(); }
AbstractTensorInterface* CreateInt64Scalar(int64_t value) override;
AbstractTensorInterface* CreateUint64Scalar(uint64 value) override;
@@ -285,10 +282,6 @@
void ResetGlobalRendezvousForFunction() override {
mutex_lock l(global_rendezvous_mu_);
- // Remove the global rendezvous instance from the local rendezvous table
- // if it uses local rendezvous type, which forces EagerContext to create a
- // new local rendezvous instance in the table.
- local_rendezvous_table_->Remove(-1);
global_rendezvous_for_functions_ =
core::RefCountPtr<Rendezvous>(CreateRendezvous(-1));
}
@@ -533,21 +526,6 @@
DistributedFunctionLibraryRuntime* cluster_flr);
private:
- // The class for wrapping a map of step_id to local rendezvous instances.
- class LocalRendezvousTable {
- public:
- LocalRendezvousTable() = default;
- ~LocalRendezvousTable() = default;
-
- Rendezvous* FindOrCreate(int64_t step_id, DeviceMgr* device_mgr);
- void Remove(int64_t step_id);
- void CleanUpAll();
-
- private:
- mutable mutex table_lock_;
- absl::flat_hash_map<int64_t, Rendezvous*> table_ TF_GUARDED_BY(table_lock_);
- };
-
Rendezvous* CreateRendezvous(int64_t step_id) const {
if (rendezvous_creator_ != nullptr) {
return rendezvous_creator_(step_id);
@@ -562,7 +540,7 @@
#endif
if (remote_device_mgr() == nullptr) {
- return local_rendezvous_table_->FindOrCreate(step_id, local_device_mgr());
+ return new IntraProcessRendezvous(local_device_mgr());
}
return nullptr;
@@ -704,10 +682,6 @@
const bool log_memory_;
- // The table of local rendezvous instances for intra-process communication.
- // This make sures only one local rendezvous instance exists per step id.
- std::unique_ptr<LocalRendezvousTable> local_rendezvous_table_;
-
// Whether to use same rendezvous instance across function/eager executions.
std::atomic<bool> reuse_rendezvous_for_functions_{false};
mutable mutex global_rendezvous_mu_;
diff --git a/tensorflow/core/common_runtime/eager/context_test.cc b/tensorflow/core/common_runtime/eager/context_test.cc
index a3a8a6a..724dd00 100644
--- a/tensorflow/core/common_runtime/eager/context_test.cc
+++ b/tensorflow/core/common_runtime/eager/context_test.cc
@@ -303,77 +303,5 @@
retvals[0] = nullptr;
}
-TEST_F(EagerContextTest, LocalRendezvousCreation) {
- InitContext(SessionOptions(), DEVICE_PLACEMENT_EXPLICIT);
- std::function<Rendezvous*(const int64_t)> rendezvous_creator =
- context()->RendezvousCreator();
-
- // Create a new rendezvous instance.
- // Initially its ref-count is 2:
- // one added upopn rendezvous creation, the other one added by EagerContext.
- Rendezvous* rendezvous_1 = rendezvous_creator(1);
- EXPECT_EQ(rendezvous_1->RefCount(), 2);
-
- // Create another rendezvous instance with the same step-id.
- // This would add one more ref-count to the existing rendezvous insteance
- // insted of creating a new instance.
- Rendezvous* rendezvous_2 = rendezvous_creator(1);
- EXPECT_EQ(rendezvous_2->RefCount(), 3);
-
- // Caller releases rendezvous-1.
- rendezvous_1->Unref();
- EXPECT_EQ(rendezvous_1->RefCount(), 2);
-
- // Caller releases rendezvous-2.
- rendezvous_2->Unref();
- EXPECT_EQ(rendezvous_2->RefCount(), 1);
-
- // Final clean up for the unit test.
- // In normal cases, EagerContext::Release() is called to clean up the
- // remaining rendezvous instances. Unit tests have different cleanup mechanism
- // so explicitly calling Release() here would mess up the cleanup.
- rendezvous_1->Unref();
-}
-
-void TestGlobalRendezvous(EagerContext* context, bool reuse_global_rendezvous) {
- context->SetReuseRendezvousForFunctions(reuse_global_rendezvous);
- EXPECT_EQ(context->GetReuseRendezvousForFunctions(), reuse_global_rendezvous);
-
- auto rendezvous_creator = context->RendezvousCreator();
- Rendezvous* rendezvous_1 = rendezvous_creator(-1);
- EXPECT_EQ(rendezvous_1->RefCount(), 2);
- Rendezvous* rendezvous_2 = rendezvous_creator(-1);
- EXPECT_EQ(rendezvous_2->RefCount(), 3);
-
- // Global rendezvous's ref-count should be back to 1 after resetting.
- context->ResetGlobalRendezvousForFunction();
-
- Rendezvous* rendezvous_3 = rendezvous_creator(-1);
- EXPECT_EQ(rendezvous_3->RefCount(), 2);
-
- // Final clean up for the unit test.
- // In normal cases, EagerContext::Release() is called by TFE_DeleteContext()
- // in order to clean up the remaining rendezvous instances. The unit tests in
- // this file do not go through TFE_DeleteContext() so we need to explicitly
- // unref for the remaining rendezvous instance to prevent the tests from
- // memory leak.
- rendezvous_1->Unref();
- rendezvous_2->Unref();
- rendezvous_3->Unref();
-}
-
-TEST_F(EagerContextTest, GlobalRendezvousCreation) {
- InitContext(SessionOptions(), DEVICE_PLACEMENT_EXPLICIT);
-
- TestGlobalRendezvous(context(), false);
-}
-
-TEST_F(EagerContextTest, ReuseGlobalRendezvous) {
- InitContext(SessionOptions(), DEVICE_PLACEMENT_EXPLICIT);
- EXPECT_FALSE(context()->GetReuseRendezvousForFunctions());
-
- TestGlobalRendezvous(context(), true);
-}
-
} // namespace
} // namespace tensorflow