Re-initialize the ProcessFLR device set and the prioritized device type list when updating the cluster.
PiperOrigin-RevId: 303420309
Change-Id: Ia9afe83873043f15e34312a93979726933c880d5
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index a8513b9..33221e5 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -1102,6 +1102,7 @@
if (rendezvous_ != nullptr) rendezvous_->Unref();
rendezvous_ = r;
remote_eager_workers_ = std::move(remote_eager_workers);
+ pflr_->InitializeDeviceSet();
InitPrioritizedDeviceTypeList();
default_executor_.ClearError();
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 76c5f34..1543a34 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -110,14 +110,7 @@
session_metadata_, this);
}
- DeviceMgr const* all_devices = device_mgr_;
- if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
- all_devices = parent_->remote_device_mgr();
- }
-
- for (auto d : all_devices->ListDevices()) {
- device_set_.AddDevice(d);
- }
+ InitializeDeviceSet();
}
/* static */
@@ -214,6 +207,18 @@
"function executions");
}
+void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
+ DeviceMgr const* all_devices = device_mgr_;
+ if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
+ all_devices = parent_->remote_device_mgr();
+ }
+
+ device_set_.reset(new DeviceSet);
+ for (auto d : all_devices->ListDevices()) {
+ device_set_->AddDevice(d);
+ }
+}
+
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
const string& device_name) const {
Device* device = nullptr;
@@ -678,7 +683,7 @@
TF_RETURN_IF_ERROR(
SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
TF_RETURN_IF_ERROR(PinArgsAndRets(
- options.input_devices, options.output_devices, device_set_, arg_nodes,
+ options.input_devices, options.output_devices, *device_set_, arg_nodes,
ret_nodes,
options.config_proto.allow_soft_placement() ? default_device : nullptr));
@@ -691,7 +696,7 @@
bool control_rets_updated = false;
TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
- device_set_, options.config_proto, &graph, &data->lib_def_,
+ *device_set_, options.config_proto, &graph, &data->lib_def_,
&control_ret_node_names, &control_rets_updated));
if (control_rets_updated) {
@@ -714,7 +719,7 @@
optimization_options.session_options = &session_options;
optimization_options.graph = &graph;
optimization_options.flib_def = &data->lib_def_;
- optimization_options.device_set = &device_set_;
+ optimization_options.device_set = device_set_.get();
optimization_options.is_function_graph = true;
DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
@@ -725,7 +730,7 @@
// exceptions/warnings in case where nested function call options are ignored.
DumpGraph("Before calling Placer", graph.get());
Placer placer(graph.get(), function_name, optimization_options.flib_def,
- &device_set_, default_device,
+ device_set_.get(), default_device,
options.config_proto.allow_soft_placement(),
options.config_proto.log_device_placement());
TF_RETURN_IF_ERROR(placer.Run());
@@ -741,7 +746,7 @@
DumpGraph("Before running graph optimization fn", graph.get());
Status status = options.optimize_graph_fn(
std::move(ret_node_names), std::move(control_ret_node_names),
- &data->lib_def_, device_set_, cpu_device, &graph);
+ &data->lib_def_, *device_set_, cpu_device, &graph);
if (!status.ok()) {
LOG(WARNING) << "Ignoring multi-device function optimization failure: "
<< status.ToString();
@@ -765,7 +770,7 @@
std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
TF_RETURN_IF_ERROR(
- PartitionFunctionGraph(device_set_, std::move(graph), &subgraphs));
+ PartitionFunctionGraph(*device_set_, std::move(graph), &subgraphs));
for (const auto& pair : subgraphs) {
DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
@@ -841,7 +846,7 @@
const string& target = pair.first;
const string& device_type =
- device_set_.FindDeviceByName(target)->device_type();
+ device_set_->FindDeviceByName(target)->device_type();
Graph* subgraph = pair.second.get();
status->Update(UpdateArgAndRetvalMetadata(
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 1d7708f..f8550fd 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -191,7 +191,10 @@
const DeviceMgr* device_mgr() { return device_mgr_; }
- const DeviceSet* device_set() { return &device_set_; }
+ const DeviceSet* device_set() { return device_set_.get(); }
+
+ // Initialize the set of local and remote devices for op device selection.
+ void InitializeDeviceSet();
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
@@ -422,7 +425,7 @@
Env* const env_;
const absl::optional<const ConfigProto> config_;
const DeviceMgr* const device_mgr_;
- DeviceSet device_set_;
+ std::unique_ptr<DeviceSet> device_set_;
const FunctionLibraryDefinition* lib_def_;
thread::ThreadPool* default_thread_pool_;