Re-initialize device and type priority list in ProcessFLR when updating cluster.

PiperOrigin-RevId: 303420309
Change-Id: Ia9afe83873043f15e34312a93979726933c880d5
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index a8513b9..33221e5 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -1102,6 +1102,7 @@
     if (rendezvous_ != nullptr) rendezvous_->Unref();
     rendezvous_ = r;
     remote_eager_workers_ = std::move(remote_eager_workers);
+    pflr_->InitializeDeviceSet();
     InitPrioritizedDeviceTypeList();
 
     default_executor_.ClearError();
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 76c5f34..1543a34 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -110,14 +110,7 @@
         session_metadata_, this);
   }
 
-  DeviceMgr const* all_devices = device_mgr_;
-  if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
-    all_devices = parent_->remote_device_mgr();
-  }
-
-  for (auto d : all_devices->ListDevices()) {
-    device_set_.AddDevice(d);
-  }
+  InitializeDeviceSet();
 }
 
 /* static */
@@ -214,6 +207,18 @@
                           "function executions");
 }
 
+void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
+  DeviceMgr const* all_devices = device_mgr_;
+  if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
+    all_devices = parent_->remote_device_mgr();
+  }
+
+  device_set_.reset(new DeviceSet);
+  for (auto d : all_devices->ListDevices()) {
+    device_set_->AddDevice(d);
+  }
+}
+
 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
     const string& device_name) const {
   Device* device = nullptr;
@@ -678,7 +683,7 @@
   TF_RETURN_IF_ERROR(
       SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
   TF_RETURN_IF_ERROR(PinArgsAndRets(
-      options.input_devices, options.output_devices, device_set_, arg_nodes,
+      options.input_devices, options.output_devices, *device_set_, arg_nodes,
       ret_nodes,
       options.config_proto.allow_soft_placement() ? default_device : nullptr));
 
@@ -691,7 +696,7 @@
 
   bool control_rets_updated = false;
   TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
-      device_set_, options.config_proto, &graph, &data->lib_def_,
+      *device_set_, options.config_proto, &graph, &data->lib_def_,
       &control_ret_node_names, &control_rets_updated));
 
   if (control_rets_updated) {
@@ -714,7 +719,7 @@
   optimization_options.session_options = &session_options;
   optimization_options.graph = &graph;
   optimization_options.flib_def = &data->lib_def_;
-  optimization_options.device_set = &device_set_;
+  optimization_options.device_set = device_set_.get();
   optimization_options.is_function_graph = true;
 
   DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
@@ -725,7 +730,7 @@
   // exceptions/warnings in case where nested function call options are ignored.
   DumpGraph("Before calling Placer", graph.get());
   Placer placer(graph.get(), function_name, optimization_options.flib_def,
-                &device_set_, default_device,
+                device_set_.get(), default_device,
                 options.config_proto.allow_soft_placement(),
                 options.config_proto.log_device_placement());
   TF_RETURN_IF_ERROR(placer.Run());
@@ -741,7 +746,7 @@
     DumpGraph("Before running graph optimization fn", graph.get());
     Status status = options.optimize_graph_fn(
         std::move(ret_node_names), std::move(control_ret_node_names),
-        &data->lib_def_, device_set_, cpu_device, &graph);
+        &data->lib_def_, *device_set_, cpu_device, &graph);
     if (!status.ok()) {
       LOG(WARNING) << "Ignoring multi-device function optimization failure: "
                    << status.ToString();
@@ -765,7 +770,7 @@
 
   std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
   TF_RETURN_IF_ERROR(
-      PartitionFunctionGraph(device_set_, std::move(graph), &subgraphs));
+      PartitionFunctionGraph(*device_set_, std::move(graph), &subgraphs));
 
   for (const auto& pair : subgraphs) {
     DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
@@ -841,7 +846,7 @@
       const string& target = pair.first;
 
       const string& device_type =
-          device_set_.FindDeviceByName(target)->device_type();
+          device_set_->FindDeviceByName(target)->device_type();
       Graph* subgraph = pair.second.get();
 
       status->Update(UpdateArgAndRetvalMetadata(
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 1d7708f..f8550fd 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -191,7 +191,10 @@
 
   const DeviceMgr* device_mgr() { return device_mgr_; }
 
-  const DeviceSet* device_set() { return &device_set_; }
+  const DeviceSet* device_set() { return device_set_.get(); }
+
+  // Initialize the set of local and remote devices for op device selection.
+  void InitializeDeviceSet();
 
   const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
 
@@ -422,7 +425,7 @@
   Env* const env_;
   const absl::optional<const ConfigProto> config_;
   const DeviceMgr* const device_mgr_;
-  DeviceSet device_set_;
+  std::unique_ptr<DeviceSet> device_set_;
   const FunctionLibraryDefinition* lib_def_;
   thread::ThreadPool* default_thread_pool_;