Dependency optimizer groups cross-host control edges to reduce RPC traffic.
PiperOrigin-RevId: 279344543
Change-Id: Iff504d1be03b02afaf9a945fc5b19b3a7ac57b5a
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index ace1308..80e6bcf 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -362,6 +362,8 @@
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
"//tensorflow/core/grappler/utils:grappler_test",
"//tensorflow/core/grappler/utils:topological_sort",
+ "//tensorflow/core/platform:test",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 0734c32..fed6003 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -19,6 +19,7 @@
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
@@ -613,11 +614,18 @@
// We can reduce cross-device communication by introducing an intermediate
// NoOp node C' on device X and rewriting the control edges to:
// A->C', B->C', C' -> C
-void DependencyOptimizer::GroupCrossDeviceControlEdges() {
+void DependencyOptimizer::GroupCrossDeviceControlEdges(bool host_granularity) {
+ VLOG(1)
+ << "DependencyOptimizer::GroupCrossDeviceControlEdges host_granularity="
+ << host_granularity;
const int num_nodes = optimized_graph_->node_size();
for (int i = 0; i < num_nodes; ++i) {
NodeDef* node = optimized_graph_->mutable_node(i);
if (node->device().empty()) continue;
+ string rest, node_device = node->device();
+ if (host_granularity) {
+ DeviceNameUtils::SplitDeviceName(node->device(), &node_device, &rest);
+ }
// Creates new noop nodes for devices on which multiple control inputs are
// located.
@@ -630,11 +638,19 @@
for (int j = 0; j < node->input_size(); ++j) {
if (IsControlInput(node->input(j))) {
const NodeDef* input = node_map_->GetNode(node->input(j));
- if (input != nullptr && !input->device().empty() &&
- input->device() != node->device()) {
- auto emplace_result = noops.emplace(input->device(), nullptr);
+ if (input == nullptr || input->device().empty()) continue;
+ string input_device = input->device();
+ if (host_granularity) {
+ DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
+ &rest);
+ }
+ if (input_device != node_device) {
+ VLOG(2) << "Cross-device " << node->name() << " " << input->device()
+ << " -> " << node->device();
+ auto emplace_result = noops.emplace(input_device, nullptr);
if (!emplace_result.second &&
emplace_result.first->second == nullptr) {
+ VLOG(2) << "Duplicate input device from " << node->name();
// This is the second cross-device control input from the same
// device. Creates an intermediate noop node on that device.
string group_name;
@@ -654,6 +670,8 @@
noop->set_op("NoOp");
node_map_->AddNode(noop->name(), noop);
emplace_result.first->second = noop;
+ VLOG(1) << "GroupCrossDeviceControlEdges: Added "
+ << SummarizeNodeDef(*noop);
}
}
}
@@ -668,10 +686,16 @@
if (input == nullptr) {
++pos;
} else {
- auto it = noops.find(input->device());
+ string input_device = input->device();
+ if (host_granularity) {
+ DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
+ &rest);
+ }
+ auto it = noops.find(input_device);
if (it == noops.end() || it->second == nullptr) {
++pos;
} else {
+ VLOG(2) << "Rewriting input from " << input_name;
node->mutable_input()->SwapElements(pos, node->input_size() - 1);
node->mutable_input()->RemoveLast();
it->second->add_input(AsControlDependency(*input));
@@ -725,7 +749,11 @@
// Dedup control inputs.
CleanControlInputs();
- GroupCrossDeviceControlEdges();
+ // Merge multiple control edges from the same device.
+ GroupCrossDeviceControlEdges(/*host_granularity=*/false);
+
+ // Merge control edges from the same host to reduce RPC traffic.
+ GroupCrossDeviceControlEdges(/*host_granularity=*/true);
}
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
index feeddc2..bb34b25 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
@@ -71,8 +71,9 @@
// Main driver of dependency optimizations.
Status OptimizeDependencies();
// Replaces multiple cross-device control edges from the same device with a
- // single control edge.
- void GroupCrossDeviceControlEdges();
+ // single control edge. If `host_granularity` is true then group control
+ // edges from all devices on the same host.
+ void GroupCrossDeviceControlEdges(bool host_granularity);
bool fetch_nodes_known_;
std::unordered_set<string> nodes_to_preserve_;
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
index 074b0dc..8c8107c 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
@@ -14,6 +14,8 @@
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
+
+#include "absl/strings/match.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -888,6 +890,61 @@
CompareGraphs(expected, output);
}
+TEST_F(DependencyOptimizerTest, GroupCrossHostControlDeps) {
+ GrapplerItem item;
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ std::vector<Operation> ops;
+ Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:0"),
+ {1, 2}, DT_FLOAT);
+ for (int t = 0; t < 4; ++t) {
+ for (int c = 0; c < 8; ++c) {
+ string opname = absl::StrCat("t", t, "/c", c);
+ string device = absl::StrCat("/task:", t, "/device:TPU:", c);
+ Output output = ops::RandomUniform(
+ s.WithOpName(opname).WithDevice(device), {1, 2}, DT_FLOAT);
+ ops.push_back(output.op());
+ }
+ }
+ // Node with cross-device dependencies.
+ auto fetch = ops::Identity(
+ s.WithOpName("f").WithControlDependencies(ops).WithDevice("/CPU:0"),
+ {a});
+
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back("f");
+ }
+
+ GraphDef expected;
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.ToGraphDef(&expected));
+ }
+
+ DependencyOptimizer optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(output.node_size(), item.graph.node_size() + 4);
+ std::set<string> tasks;
+ for (const auto& n : output.node()) {
+ if (n.op() == "NoOp") {
+ EXPECT_TRUE(absl::StartsWith(n.name(), "GroupCrossDeviceControlEdges"));
+ EXPECT_EQ(n.input_size(), 8);
+ tasks.insert(n.device());
+ }
+
+ if (n.name() == "f") {
+ EXPECT_EQ(n.input_size(), 5);
+ for (const auto& i : n.input()) {
+ EXPECT_TRUE(i == "a" ||
+ absl::StartsWith(i, "^GroupCrossDeviceControlEdges"));
+ }
+ }
+ }
+ EXPECT_EQ(tasks.size(), 4);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow