| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/grappler/costs/virtual_scheduler.h" |
| |
| #include "tensorflow/cc/ops/standard_ops.h" |
| #include "tensorflow/core/framework/tensor_description.pb.h" |
| #include "tensorflow/core/framework/tensor_shape.pb.h" |
| #include "tensorflow/core/grappler/clusters/virtual_cluster.h" |
| #include "tensorflow/core/grappler/costs/utils.h" |
| #include "tensorflow/core/grappler/costs/virtual_placer.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/platform/test.h" |
| |
| namespace tensorflow { |
| namespace grappler { |
| namespace { |
| |
// Names of the devices, and of the inter-device channels, that the tests below
// assign to their nodes.
constexpr char kCPU0[] = "/job:localhost/replica:0/task:0/cpu:0";
constexpr char kCPU1[] = "/job:localhost/replica:0/task:0/cpu:1";
constexpr char kChannelFrom0To1[] = "Channel from CPU0 to CPU1";
constexpr char kChannelFrom1To0[] = "Channel from CPU1 to CPU0";

// Op type strings that the tests below assign to their nodes.
constexpr char kConv2D[] = "Conv2D";
constexpr char kSend[] = "_Send";
constexpr char kRecv[] = "_Recv";
| |
| class ReadyNodeManagerTest : public ::testing::Test { |
| protected: |
| ReadyNodeManagerTest() { |
| // node1_ to node6_ on kCPU0, with time_ready in reverse_order. |
| NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_); |
| NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_); |
| NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_); |
| NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_); |
| NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_); |
| NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_); |
| } |
| |
| void NodeSetUp(const string& name, const string& op_name, |
| const string& device_name, const uint64 time_ready, |
| NodeDef* node) { |
| node->set_name(name); |
| node->set_op(op_name); |
| node->set_device(device_name); |
| |
| node_states_[node] = NodeState(); |
| node_states_[node].time_ready = time_ready; |
| node_states_[node].device_name = device_name; |
| } |
| |
| NodeDef node1_, node2_, node3_, node4_, node5_, node6_; |
| std::unordered_map<const NodeDef*, NodeState> node_states_; |
| }; |
| |
| // Tests that FIFOManager correctly returns the current node with only 1 node. |
| TEST_F(ReadyNodeManagerTest, GetSingleNodeFIFOManager) { |
| FIFOManager manager = FIFOManager(); |
| manager.AddNode(&node1_); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| } |
| |
| // Tests that FIFOManager removes the only node contained within. |
| TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFIFOManager) { |
| FIFOManager manager = FIFOManager(); |
| manager.AddNode(&node1_); |
| |
| // Removes the only node in FIFOManager. |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| // Tests that FIFOManager can remove multiple nodes and returns the current node |
| // in the right order. |
| TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFIFOManager) { |
| FIFOManager manager = FIFOManager(); |
| manager.AddNode(&node1_); |
| manager.AddNode(&node2_); |
| manager.AddNode(&node3_); |
| manager.AddNode(&node4_); |
| |
| // Keeps checking current node while removing nodes from manager. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node4"); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| // Tests that FIFOManager can remove multiple nodes and add more nodes, still |
| // returning the current node in the right order. |
| TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleFIFOManager) { |
| FIFOManager manager = FIFOManager(); |
| manager.AddNode(&node1_); |
| manager.AddNode(&node2_); |
| manager.AddNode(&node3_); |
| manager.AddNode(&node4_); |
| |
| // Keeps checking current node as nodes are removed and added. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.AddNode(&node5_); |
| // GetCurrNode() should return the same node even if some nodes are added, |
| // until RemoveCurrNode() is called. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node4"); |
| manager.AddNode(&node6_); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node4"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node5"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node6"); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| // Tests that LIFOManager correctly returns the current node with only 1 node. |
| TEST_F(ReadyNodeManagerTest, GetSingleNodeLIFOManager) { |
| LIFOManager manager = LIFOManager(); |
| manager.AddNode(&node1_); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| } |
| |
| // Tests that LIFOManager removes the only node contained within. |
| TEST_F(ReadyNodeManagerTest, RemoveSingleNodeLIFOManager) { |
| LIFOManager manager = LIFOManager(); |
| manager.AddNode(&node1_); |
| |
| // Removes the only node in LIFOManager. |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| // Tests that LIFOManager can remove multiple nodes and returns the current node |
| // in the right order. |
| TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleLIFOManager) { |
| LIFOManager manager = LIFOManager(); |
| manager.AddNode(&node1_); |
| manager.AddNode(&node2_); |
| manager.AddNode(&node3_); |
| manager.AddNode(&node4_); |
| |
| // Keeps checking current node while removing nodes from manager. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node4"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| // Tests that LIFOManager can remove multiple nodes (must be removing the |
| // current node) and add more nodes, still returning the current node in the |
| // right order. |
| TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) { |
| LIFOManager manager = LIFOManager(); |
| manager.AddNode(&node1_); |
| manager.AddNode(&node2_); |
| manager.AddNode(&node3_); |
| manager.AddNode(&node4_); |
| |
| // Keeps checking current node as nodes are removed and added. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node4"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.AddNode(&node5_); |
| // GetCurrNode() should return the same node even if some nodes are added, |
| // until RemoveCurrNode() is called. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node5"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.AddNode(&node6_); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node6"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) { |
| FirstReadyManager manager; |
| TF_EXPECT_OK(manager.Init(&node_states_)); |
| manager.AddNode(&node1_); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| } |
| |
| TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) { |
| FirstReadyManager manager; |
| TF_EXPECT_OK(manager.Init(&node_states_)); |
| manager.AddNode(&node1_); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) { |
| FirstReadyManager manager; |
| TF_EXPECT_OK(manager.Init(&node_states_)); |
| // Insert nodes in some random order. |
| manager.AddNode(&node2_); |
| manager.AddNode(&node1_); |
| manager.AddNode(&node4_); |
| manager.AddNode(&node5_); |
| manager.AddNode(&node3_); |
| manager.AddNode(&node6_); |
| |
| // In whatever order we insert nodes, we get the same order based on nodes' |
| // time_ready. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node6"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node5"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node4"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) { |
| FirstReadyManager manager; |
| TF_EXPECT_OK(manager.Init(&node_states_)); |
| |
| // Inserts nodes in some random order. |
| manager.AddNode(&node2_); |
| manager.AddNode(&node1_); |
| manager.AddNode(&node4_); |
| manager.AddNode(&node5_); |
| manager.AddNode(&node3_); |
| manager.AddNode(&node6_); |
| |
| // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode() |
| // should return it. |
| EXPECT_EQ("Node6", manager.GetCurrNode()->name()); |
| |
| // Now inserts a few other nodes, but their time_ready's are even smaller than |
| // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return |
| // the same node, Node6, in this case. |
| NodeDef node7; |
| NodeDef node8; |
| NodeDef node9; |
| NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7); |
| NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8); |
| NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9); |
| |
| manager.AddNode(&node7); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node6"); |
| |
| manager.AddNode(&node8); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node6"); |
| |
| manager.RemoveCurrNode(); |
| // Now Node6 is removed, and GetCurrNode() will return Node8. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node8"); |
| |
| // Again, AddNode shouldn't change GetCurrNode(). |
| manager.AddNode(&node9); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node8"); |
| |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node9"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node7"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node5"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node4"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) { |
| FirstReadyManager manager1; |
| TF_EXPECT_OK(manager1.Init(&node_states_)); |
| FirstReadyManager manager2; |
| TF_EXPECT_OK(manager2.Init(&node_states_)); |
| |
| // 6 nodes with same time_ready. |
| NodeDef node7; |
| NodeDef node8; |
| NodeDef node9; |
| NodeDef node10; |
| NodeDef node11; |
| NodeDef node12; |
| NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7); |
| NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8); |
| NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9); |
| NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10); |
| NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11); |
| NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12); |
| |
| // Adds the above 6 nodes to manager1. |
| manager1.AddNode(&node7); |
| manager1.AddNode(&node8); |
| manager1.AddNode(&node9); |
| manager1.AddNode(&node10); |
| manager1.AddNode(&node11); |
| manager1.AddNode(&node12); |
| |
| // Adds the above 6 nodes to manager2, but in a different order. |
| manager2.AddNode(&node8); |
| manager2.AddNode(&node11); |
| manager2.AddNode(&node9); |
| manager2.AddNode(&node10); |
| manager2.AddNode(&node7); |
| manager2.AddNode(&node12); |
| |
| // Expects both managers return the same nodes for deterministic node |
| // scheduling. |
| EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); |
| manager1.RemoveCurrNode(); |
| manager2.RemoveCurrNode(); |
| |
| EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); |
| manager1.RemoveCurrNode(); |
| manager2.RemoveCurrNode(); |
| |
| EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); |
| manager1.RemoveCurrNode(); |
| manager2.RemoveCurrNode(); |
| |
| EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); |
| manager1.RemoveCurrNode(); |
| manager2.RemoveCurrNode(); |
| |
| EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); |
| manager1.RemoveCurrNode(); |
| manager2.RemoveCurrNode(); |
| |
| EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name()); |
| manager1.RemoveCurrNode(); |
| manager2.RemoveCurrNode(); |
| |
| EXPECT_TRUE(manager1.Empty()); |
| EXPECT_TRUE(manager2.Empty()); |
| } |
| |
| TEST_F(ReadyNodeManagerTest, GetAndRemoveMultiplePriorityReadyManager) { |
| PriorityReadyManager manager; |
| TF_EXPECT_OK(manager.Init(&node_states_)); |
| |
| // Sets up node priorities. |
| std::unordered_map<string, int> node_priority = { |
| {"Node1", 1}, {"Node2", 2}, {"Node3", 2}, {"Node4", 4}, {"Node5", 5}}; |
| TF_EXPECT_OK(manager.SetPriority(node_priority)); |
| |
| // Inserts nodes in some random order. |
| manager.AddNode(&node3_); |
| manager.AddNode(&node1_); |
| manager.AddNode(&node4_); |
| manager.AddNode(&node5_); |
| manager.AddNode(&node2_); |
| manager.AddNode(&node6_); |
| |
| // Expects nodes scheduled based on priority. |
| // Node6 should default to lowest priority, since it is not found. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node6"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| manager.RemoveCurrNode(); |
| // Nodes 2 and 3 have equal priority and so should be scheduled ready-first. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node4"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node5"); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) { |
| CompositeNodeManager manager; |
| TF_EXPECT_OK(manager.Init(&node_states_)); |
| manager.AddNode(&node1_); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleCompositeNodeManager) { |
| CompositeNodeManager manager; |
| TF_EXPECT_OK(manager.Init(&node_states_)); |
| manager.AddNode(&node1_); |
| manager.AddNode(&node2_); |
| manager.AddNode(&node3_); |
| manager.AddNode(&node4_); |
| |
| // Keeps checking current node as nodes are removed and added. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node4"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.AddNode(&node5_); |
| // GetCurrNode() should return the same node even if some nodes are added, |
| // until RemoveCurrNode() is called. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node5"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.AddNode(&node6_); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node6"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvCompositeNodeManager) { |
| CompositeNodeManager manager; |
| TF_EXPECT_OK(manager.Init(&node_states_)); |
| // Additional nodes on kCPU1. |
| NodeDef node7; |
| NodeDef node8; |
| NodeDef node9; |
| NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7); |
| NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8); |
| NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9); |
| |
| // Send and Recv nodes. |
| NodeDef send1; |
| NodeDef send2; |
| NodeDef recv1; |
| NodeDef recv2; |
| NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1); |
| NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2); |
| NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1); |
| NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2); |
| |
| // Inserts nodes. |
| manager.AddNode(&node1_); |
| manager.AddNode(&node2_); |
| manager.AddNode(&node3_); |
| manager.AddNode(&node4_); |
| manager.AddNode(&node5_); |
| manager.AddNode(&node6_); |
| manager.AddNode(&node7); |
| manager.AddNode(&node8); |
| manager.AddNode(&node9); |
| manager.AddNode(&send1); |
| manager.AddNode(&send2); |
| manager.AddNode(&recv1); |
| manager.AddNode(&recv2); |
| |
| // On kCPU0; last one is node6_, on kCPU1: last one is node9; |
| // so choose one that has earliest time_ready among node6_, node9, |
| // Send1, Send2, Recv1, and Recv2. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node6"); |
| manager.RemoveCurrNode(); |
| // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node |
| // among node5_, node9, Send1, Send2, Recv1, and Recv2. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node5"); |
| manager.RemoveCurrNode(); |
| // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Send1"); |
| manager.RemoveCurrNode(); |
| // Next, choose among node4_, node9, Sen2, Recv1, and Recv2. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Recv1"); |
| manager.RemoveCurrNode(); |
| // Next, choose among node4_, node9, Send2, and Recv2. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Recv2"); |
| manager.RemoveCurrNode(); |
| // Next, choose among node4_, node9, and Send2. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Send2"); |
| manager.RemoveCurrNode(); |
| // Next, choose between node4_, node9. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node4"); |
| manager.RemoveCurrNode(); |
| // Next, choose between node3_, node9. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node9"); |
| manager.RemoveCurrNode(); |
| // Next, choose between node3_, node8. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node8"); |
| manager.RemoveCurrNode(); |
| // Next, choose between node3_, node7. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node7"); |
| manager.RemoveCurrNode(); |
| // Then, just the nodes on kCPU1 -- LIFO. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node3"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node2"); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node1"); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) { |
| CompositeNodeManager manager; |
| TF_EXPECT_OK(manager.Init(&node_states_)); |
| CompositeNodeManager manager2; |
| TF_EXPECT_OK(manager2.Init(&node_states_)); |
| |
| // 6 nodes with same time_ready. |
| NodeDef node7; |
| NodeDef node8; |
| NodeDef node9; |
| NodeDef node10; |
| NodeDef node11; |
| NodeDef node12; |
| NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7); |
| NodeSetUp("Node8", kSend, kCPU0, 1000, &node8); |
| NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9); |
| NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10); |
| NodeSetUp("Node11", kRecv, kCPU0, 999, &node11); |
| NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12); |
| |
| // Adds Nodes 7 to 9 to manager. |
| manager.AddNode(&node7); |
| manager.AddNode(&node8); |
| manager.AddNode(&node9); |
| |
| // It should return _Send, Recv, and the other op order, when the candidate |
| // nodes have same time_ready. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node8"); |
| EXPECT_EQ(manager.GetCurrNode()->op(), kSend); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node9"); |
| EXPECT_EQ(manager.GetCurrNode()->op(), kRecv); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node7"); |
| EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| |
| // Adds Nodes 7 to 9 to manager, but in a different order. |
| manager.AddNode(&node9); |
| manager.AddNode(&node8); |
| manager.AddNode(&node7); |
| |
| // Expects same order (_Send, _Recv, and the other op), regardless of Add |
| // order. |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node8"); |
| EXPECT_EQ(manager.GetCurrNode()->op(), kSend); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node9"); |
| EXPECT_EQ(manager.GetCurrNode()->op(), kRecv); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node7"); |
| EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| |
| // Conv2D's time_ready < Send's time_ready; Expects Conv2D first. |
| manager.AddNode(&node8); |
| manager.AddNode(&node10); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node10"); |
| EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node8"); |
| EXPECT_EQ(manager.GetCurrNode()->op(), kSend); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| |
| // Recv's time_ready < Send' time_ready; Expects Recv first. |
| manager.AddNode(&node11); |
| manager.AddNode(&node8); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node11"); |
| EXPECT_EQ(manager.GetCurrNode()->op(), kRecv); |
| manager.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), "Node8"); |
| EXPECT_EQ(manager.GetCurrNode()->op(), kSend); |
| manager.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| |
| // Node7 and 12 are normal ops with the same time_ready, placed on different |
| // devices. These two nodes are added to manager and manager2, but in |
| // different orders; Expects GetCurrNode() returns the nodes in the same |
| // order. |
| manager.AddNode(&node7); |
| manager.AddNode(&node12); |
| |
| manager2.AddNode(&node12); |
| manager2.AddNode(&node7); |
| |
| EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name()); |
| manager.RemoveCurrNode(); |
| manager2.RemoveCurrNode(); |
| EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name()); |
| manager.RemoveCurrNode(); |
| manager2.RemoveCurrNode(); |
| EXPECT_TRUE(manager.Empty()); |
| } |
| |
| // Class for testing virtual scheduler. |
| class TestVirtualScheduler : public VirtualScheduler { |
| public: |
| TestVirtualScheduler(const bool use_static_shapes, |
| const bool use_aggressive_shape_inference, |
| ReadyNodeManager* ready_node_manager, Cluster* cluster) |
| : VirtualScheduler( |
| use_static_shapes, use_aggressive_shape_inference, cluster, |
| ready_node_manager, |
| absl::make_unique<VirtualPlacer>(cluster->GetDevices())) { |
| enable_mem_usage_tracking(); |
| } |
| |
| FRIEND_TEST(VirtualSchedulerTest, MemoryUsage); |
| FRIEND_TEST(VirtualSchedulerTest, ControlDependency); |
| FRIEND_TEST(VirtualSchedulerTest, ComplexDependency); |
| FRIEND_TEST(VirtualSchedulerTest, Variable); |
| FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer); |
| }; |
| |
| class VirtualSchedulerTest : public ::testing::Test { |
| protected: |
| VirtualSchedulerTest() { |
| // Initializes cluster_ and scheduler_. |
| std::unordered_map<string, DeviceProperties> devices; |
| |
| // Set some dummy CPU properties |
| DeviceProperties cpu_device = GetDummyCPUDevice(); |
| |
| // IMPORTANT: Device is not actually ever used in the test case since |
| // force_cpu_type is defaulted to "Haswell" |
| devices[kCPU0] = cpu_device; |
| devices[kCPU1] = cpu_device; |
| cluster_ = absl::make_unique<VirtualCluster>(devices); |
| scheduler_ = absl::make_unique<TestVirtualScheduler>( |
| /*use_static_shapes=*/true, |
| /*use_aggressive_shape_inference=*/true, &first_ready_manager_, |
| cluster_.get()); |
| } |
| |
| DeviceProperties GetDummyCPUDevice() { |
| // Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth. |
| // - 8 Gflops |
| // - 2 GB/s |
| DeviceProperties cpu_device; |
| cpu_device.set_type("CPU"); |
| cpu_device.set_frequency(4000); |
| cpu_device.set_num_cores(2); |
| cpu_device.set_bandwidth(2000000); |
| return cpu_device; |
| } |
| |
| // Three Conv2Ds with only two in fetch nodes. |
| void CreateGrapplerItemWithConv2Ds() { |
| Scope s = Scope::NewRootScope().WithDevice(kCPU0); |
| auto x = ops::RandomUniform( |
| s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); |
| auto y = ops::RandomUniform( |
| s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); |
| auto z = ops::RandomUniform( |
| s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); |
| auto f = ops::RandomUniform( |
| s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT); |
| std::vector<int> strides = {1, 1, 1, 1}; |
| auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME"); |
| auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME"); |
| auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME"); |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph)); |
| grappler_item_->id = "test_conv2d_graph"; |
| grappler_item_->fetch = {"c0", "c1"}; |
| |
| dependency_["c0"] = {"x", "f"}; |
| dependency_["c1"] = {"y", "f"}; |
| } |
| |
| // A Conv2D with a variable. |
| void CreateGrapplerItemWithConv2DAndVariable() { |
| Scope s = Scope::NewRootScope().WithDevice(kCPU0); |
| auto x = ops::RandomUniform( |
| s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); |
| auto f = ops::Variable(s.WithOpName("f"), |
| {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT); |
| std::vector<int> strides = {1, 1, 1, 1}; |
| auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME"); |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph)); |
| grappler_item_->id = "test_conv2d_var_graph"; |
| |
| grappler_item_->fetch = {"y"}; |
| |
| dependency_["y"] = {"x", "f"}; |
| } |
| |
| void CreateGrapplerItemWithMatmulChain() { |
| Scope s = Scope::NewRootScope().WithDevice(kCPU0); |
| // Add control dependencies to ensure tests do not rely on specific |
| // manager and the order remains consistent for the test. |
| auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT); |
| auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a), |
| {3200, 3200}, DT_FLOAT); |
| auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b), |
| {3200, 3200}, DT_FLOAT); |
| auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c), |
| {3200, 3200}, DT_FLOAT); |
| auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d), |
| {3200, 3200}, DT_FLOAT); |
| |
| auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b); |
| auto abc = ops::MatMul(s.WithOpName("abc"), ab, c); |
| auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d); |
| auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e); |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph)); |
| grappler_item_->id = "test_matmul_sequence_graph"; |
| grappler_item_->fetch = {"abcde"}; |
| |
| dependency_["ab"] = {"a", "b"}; |
| dependency_["abc"] = {"ab", "c"}; |
| dependency_["abcd"] = {"abc", "d"}; |
| dependency_["abcde"] = {"abcd", "e"}; |
| } |
| |
| // AddN that takes 4 tensors with 10x10x10x10. |
| void CreateGrapplerItemWithAddN() { |
| Scope s = Scope::NewRootScope().WithDevice(kCPU0); |
| auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT); |
| auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT); |
| auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT); |
| auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT); |
| OutputList input_tensors = {x, y, z, w}; |
| auto out = ops::AddN(s.WithOpName("out"), input_tensors); |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph)); |
| grappler_item_->id = "test_addn_graph"; |
| grappler_item_->fetch = {"out"}; |
| |
| dependency_["out"] = {"x", "y", "z", "w"}; |
| } |
| |
| // Graph with some placeholder feed nodes that are not in the fetch fan-in. |
| void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() { |
| Scope s = Scope::NewRootScope().WithDevice(kCPU0); |
| auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT); |
| auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT); |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph)); |
| |
| grappler_item_->id = "test_extra_placeholders"; |
| grappler_item_->fetch = {"x"}; |
| |
| // Grappler Item Builder puts all placeholder nodes into the feed |
| // list by default. |
| grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}}; |
| } |
| |
| // NoOp that takes 7 NoOps as control dependency. |
| void CreateGrapplerItemWithControlDependency() { |
| Scope s = Scope::NewRootScope().WithDevice(kCPU0); |
| std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"}; |
| std::vector<Operation> input_tensors; |
| for (const auto& input : input_noop_names) { |
| auto x = ops::NoOp(s.WithOpName(input)); |
| input_tensors.push_back(x.operation); |
| } |
| auto out = |
| ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out")); |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph)); |
| |
| grappler_item_->id = "test_control_dependency_graph"; |
| grappler_item_->fetch = {"out"}; |
| |
| dependency_["out"] = input_noop_names; |
| } |
| |
| void CreateGrapplerItemWithAddFromOneTensor() { |
| Scope s = Scope::NewRootScope().WithDevice(kCPU0); |
| auto x = tensorflow::ops::RandomUniform( |
| s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); |
| |
| auto y = tensorflow::ops::Add(s.WithOpName("y"), x, x); |
| Output fetch = ops::Identity(s.WithOpName("fetch"), y); |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph)); |
| |
| grappler_item_->id = "test_add_from_one_tensor"; |
| grappler_item_->fetch = {"fetch"}; |
| |
| dependency_["fetch"] = {"y"}; |
| dependency_["y"] = {"x"}; |
| } |
| |
| void CreateGrapplerItemWithSwitchMergeInput() { |
| // sw = Switch(x, pred) |
| // a = Add(S:1, b) |
| // m = Merge(sw:0, a) |
| // y = Add(m, z) |
| |
| Scope s = Scope::NewRootScope().WithDevice(kCPU0); |
| auto x = ops::RandomUniform( |
| s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); |
| auto pred = ops::Const(s.WithOpName("pred"), false, {}); |
| auto sw = ops::Switch(s.WithOpName("switch"), x, pred); |
| auto b = ops::RandomUniform( |
| s.WithOpName("b"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); |
| auto a = ops::Add(s.WithOpName("a"), sw.output_true, b); |
| auto m = ops::Merge(s.WithOpName("m"), {sw.output_false, a.z}); |
| auto z = ops::RandomUniform( |
| s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); |
| auto y = ops::Add(s.WithOpName("y"), m.output, z); |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph)); |
| |
| grappler_item_->id = "test_add_merge_switch"; |
| grappler_item_->fetch = {"y"}; |
| |
| dependency_["y"] = {"m", "z"}; |
| } |
| |
| // FusedBN [an op with multiple outputs] with multiple consumers (including |
| // control dependency). |
| void CreateGrapplerItemWithBatchNorm() { |
| Scope s = Scope::NewRootScope().WithDevice(kCPU0); |
| auto x = ops::RandomUniform( |
| s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); |
| auto scale = |
| ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT); |
| auto offset = |
| ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT); |
| auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT); |
| auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT); |
| |
| auto batch_norm = ops::FusedBatchNorm( |
| s.WithOpName("bn"), x, scale, offset, mean, var, |
| ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f)); |
| auto y = batch_norm.y; |
| auto batch_mean = batch_norm.batch_mean; |
| auto batch_var = batch_norm.batch_variance; |
| |
| auto z1 = ops::Add(s.WithOpName("z1"), x, y); |
| auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var); |
| auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var); |
| std::vector<Operation> input_tensors = { |
| batch_mean.op(), |
| z1.z.op(), |
| z2.z.op(), |
| z3.z.op(), |
| }; |
| auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4")); |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph)); |
| |
| grappler_item_->id = "test_complex_dependency_graph"; |
| grappler_item_->fetch = {"z1", "z2", "z3", "z4"}; |
| |
| dependency_["bn"] = {"x", "scale", "offset", "mean", "var"}; |
| dependency_["z1"] = {"x", "bn"}; |
| dependency_["z2"] = {"bn"}; |
| dependency_["z3"] = {"bn"}; |
| dependency_["z4"] = {"bn"}; |
| } |
| |
| void CreateGrapplerItemWithSendRecv() { |
| const string gdef_ascii = R"EOF( |
| node { |
| name: "Const" |
| op: "Const" |
| device: "/job:localhost/replica:0/task:0/device:CPU:0" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_FLOAT |
| tensor_shape { |
| } |
| float_val: 3.1415 |
| } |
| } |
| } |
| } |
| node { |
| name: "Send" |
| op: "_Send" |
| input: "Const" |
| device: "/job:localhost/replica:0/task:0/device:CPU:0" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "client_terminated" |
| value { |
| b: false |
| } |
| } |
| attr { |
| key: "recv_device" |
| value { |
| s: "/job:localhost/replica:0/task:0/device:CPU:0" |
| } |
| } |
| attr { |
| key: "send_device" |
| value { |
| s: "/job:localhost/replica:0/task:0/device:CPU:0" |
| } |
| } |
| attr { |
| key: "send_device_incarnation" |
| value { |
| i: 0 |
| } |
| } |
| attr { |
| key: "tensor_name" |
| value { |
| s: "test" |
| } |
| } |
| } |
| node { |
| name: "Recv" |
| op: "_Recv" |
| device: "/job:localhost/replica:0/task:0/device:CPU:0" |
| attr { |
| key: "client_terminated" |
| value { |
| b: false |
| } |
| } |
| attr { |
| key: "recv_device" |
| value { |
| s: "/job:localhost/replica:0/task:0/device:CPU:0" |
| } |
| } |
| attr { |
| key: "send_device" |
| value { |
| s: "/job:localhost/replica:0/task:0/device:CPU:0" |
| } |
| } |
| attr { |
| key: "send_device_incarnation" |
| value { |
| i: 0 |
| } |
| } |
| attr { |
| key: "tensor_name" |
| value { |
| s: "test" |
| } |
| } |
| attr { |
| key: "tensor_type" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| } |
| library { |
| } |
| versions { |
| producer: 24 |
| } |
| )EOF"; |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| |
| CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, |
| &grappler_item_->graph)); |
| grappler_item_->id = "test_graph"; |
| grappler_item_->fetch = {"Recv"}; |
| } |
| |
| void CreateGrapplerItemWithRecvWithoutSend() { |
| const string gdef_ascii = R"EOF( |
| node { |
| name: "Recv" |
| op: "_Recv" |
| device: "/job:localhost/replica:0/task:0/device:CPU:0" |
| attr { |
| key: "client_terminated" |
| value { |
| b: false |
| } |
| } |
| attr { |
| key: "recv_device" |
| value { |
| s: "/job:localhost/replica:0/task:0/device:CPU:0" |
| } |
| } |
| attr { |
| key: "send_device" |
| value { |
| s: "/job:localhost/replica:0/task:0/device:CPU:0" |
| } |
| } |
| attr { |
| key: "send_device_incarnation" |
| value { |
| i: 0 |
| } |
| } |
| attr { |
| key: "tensor_name" |
| value { |
| s: "test" |
| } |
| } |
| attr { |
| key: "tensor_type" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| } |
| library { |
| } |
| versions { |
| producer: 24 |
| } |
| )EOF"; |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, |
| &grappler_item_->graph)); |
| grappler_item_->id = "test_graph"; |
| grappler_item_->fetch = {"Recv"}; |
| } |
| |
| // A simple while loop |
| void CreateGrapplerItemWithLoop() { |
| // Test graph produced in python using: |
| /* |
| with tf.Graph().as_default(): |
| i0 = tf.constant(0) |
| m0 = tf.ones([2, 2]) |
| c = lambda i, m: i < 10 |
| b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] |
| r = tf.while_loop( |
| c, b, loop_vars=[i0, m0], |
| shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) |
| with open('/tmp/graph.pbtxt', 'w') as f: |
| f.write(str(tf.get_default_graph().as_graph_def())) |
| */ |
| const string gdef_ascii = R"EOF( |
| node { |
| name: "Const" |
| op: "Const" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_INT32 |
| tensor_shape { |
| } |
| int_val: 0 |
| } |
| } |
| } |
| } |
| node { |
| name: "ones" |
| op: "Const" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_FLOAT |
| tensor_shape { |
| dim { |
| size: 2 |
| } |
| dim { |
| size: 2 |
| } |
| } |
| float_val: 1.0 |
| } |
| } |
| } |
| } |
| node { |
| name: "while/Enter" |
| op: "Enter" |
| input: "Const" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "frame_name" |
| value { |
| s: "while/while/" |
| } |
| } |
| attr { |
| key: "is_constant" |
| value { |
| b: false |
| } |
| } |
| attr { |
| key: "parallel_iterations" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/Enter_1" |
| op: "Enter" |
| input: "ones" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "frame_name" |
| value { |
| s: "while/while/" |
| } |
| } |
| attr { |
| key: "is_constant" |
| value { |
| b: false |
| } |
| } |
| attr { |
| key: "parallel_iterations" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/Merge" |
| op: "Merge" |
| input: "while/Enter" |
| input: "while/NextIteration" |
| attr { |
| key: "N" |
| value { |
| i: 2 |
| } |
| } |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| } |
| node { |
| name: "while/Merge_1" |
| op: "Merge" |
| input: "while/Enter_1" |
| input: "while/NextIteration_1" |
| attr { |
| key: "N" |
| value { |
| i: 2 |
| } |
| } |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| } |
| node { |
| name: "while/Less/y" |
| op: "Const" |
| input: "^while/Merge" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_INT32 |
| tensor_shape { |
| } |
| int_val: 10 |
| } |
| } |
| } |
| } |
| node { |
| name: "while/Less" |
| op: "Less" |
| input: "while/Merge" |
| input: "while/Less/y" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| } |
| node { |
| name: "while/LoopCond" |
| op: "LoopCond" |
| input: "while/Less" |
| } |
| node { |
| name: "while/Switch" |
| op: "Switch" |
| input: "while/Merge" |
| input: "while/LoopCond" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "_class" |
| value { |
| list { |
| s: "loc:@while/Merge" |
| } |
| } |
| } |
| } |
| node { |
| name: "while/Switch_1" |
| op: "Switch" |
| input: "while/Merge_1" |
| input: "while/LoopCond" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "_class" |
| value { |
| list { |
| s: "loc:@while/Merge_1" |
| } |
| } |
| } |
| } |
| node { |
| name: "while/Identity" |
| op: "Identity" |
| input: "while/Switch:1" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| } |
| node { |
| name: "while/Identity_1" |
| op: "Identity" |
| input: "while/Switch_1:1" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| } |
| node { |
| name: "while/add/y" |
| op: "Const" |
| input: "^while/Identity" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_INT32 |
| tensor_shape { |
| } |
| int_val: 1 |
| } |
| } |
| } |
| } |
| node { |
| name: "while/add" |
| op: "Add" |
| input: "while/Identity" |
| input: "while/add/y" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| } |
| node { |
| name: "while/concat/axis" |
| op: "Const" |
| input: "^while/Identity" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_INT32 |
| tensor_shape { |
| } |
| int_val: 0 |
| } |
| } |
| } |
| } |
| node { |
| name: "while/concat" |
| op: "ConcatV2" |
| input: "while/Identity_1" |
| input: "while/Identity_1" |
| input: "while/concat/axis" |
| attr { |
| key: "N" |
| value { |
| i: 2 |
| } |
| } |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "Tidx" |
| value { |
| type: DT_INT32 |
| } |
| } |
| } |
| node { |
| name: "while/NextIteration" |
| op: "NextIteration" |
| input: "while/add" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| } |
| node { |
| name: "while/NextIteration_1" |
| op: "NextIteration" |
| input: "while/concat" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| } |
| node { |
| name: "while/Exit" |
| op: "Exit" |
| input: "while/Switch" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| } |
| node { |
| name: "while/Exit_1" |
| op: "Exit" |
| input: "while/Switch_1" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| } |
| versions { |
| producer: 21 |
| } |
| )EOF"; |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, |
| &grappler_item_->graph)); |
| grappler_item_->id = "test_graph"; |
| grappler_item_->fetch = {"while/Exit", "while/Exit_1"}; |
| } |
| |
| // A simple while loop strengthened with Switch outputs xxx. |
| void CreateGrapplerItemWithLoopAnnotated() { |
| // Test graph produced in python using: |
| /* |
| with tf.Graph().as_default(): |
| i0 = tf.constant(0) |
| m0 = tf.ones([2, 2]) |
| c = lambda i, m: i < 10 |
| b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] |
| r = tf.while_loop( |
| c, b, loop_vars=[i0, m0], |
| shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) |
| with open('/tmp/graph.pbtxt', 'w') as f: |
| f.write(str(tf.get_default_graph().as_graph_def())) |
| */ |
| const string gdef_ascii = R"EOF( |
| node { |
| name: "Const" |
| op: "Const" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_INT32 |
| tensor_shape { |
| } |
| int_val: 0 |
| } |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 1 |
| } |
| } |
| } |
| node { |
| name: "ones" |
| op: "Const" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_FLOAT |
| tensor_shape { |
| dim { |
| size: 2 |
| } |
| dim { |
| size: 2 |
| } |
| } |
| float_val: 1.0 |
| } |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 1 |
| } |
| } |
| } |
| node { |
| name: "while/Enter" |
| op: "Enter" |
| input: "Const" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "frame_name" |
| value { |
| s: "while/while/" |
| } |
| } |
| attr { |
| key: "is_constant" |
| value { |
| b: false |
| } |
| } |
| attr { |
| key: "parallel_iterations" |
| value { |
| i: 10 |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 1 |
| } |
| } |
| } |
| node { |
| name: "while/Enter_1" |
| op: "Enter" |
| input: "ones" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "frame_name" |
| value { |
| s: "while/while/" |
| } |
| } |
| attr { |
| key: "is_constant" |
| value { |
| b: false |
| } |
| } |
| attr { |
| key: "parallel_iterations" |
| value { |
| i: 10 |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 1 |
| } |
| } |
| } |
| node { |
| name: "while/Merge" |
| op: "Merge" |
| input: "while/Enter" |
| input: "while/NextIteration" |
| attr { |
| key: "N" |
| value { |
| i: 2 |
| } |
| } |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/Merge_1" |
| op: "Merge" |
| input: "while/Enter_1" |
| input: "while/NextIteration_1" |
| attr { |
| key: "N" |
| value { |
| i: 2 |
| } |
| } |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/Less/y" |
| op: "Const" |
| input: "^while/Merge" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_INT32 |
| tensor_shape { |
| } |
| int_val: 10 |
| } |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/Less" |
| op: "Less" |
| input: "while/Merge" |
| input: "while/Less/y" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/LoopCond" |
| op: "LoopCond" |
| input: "while/Less" |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/Switch" |
| op: "Switch" |
| input: "while/Merge" |
| input: "while/LoopCond" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "_class" |
| value { |
| list { |
| s: "loc:@while/Merge" |
| } |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 11 |
| } |
| } |
| attr { |
| key: "_output_slot_vector" |
| value { |
| list { |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 0 |
| } |
| } |
| } |
| } |
| node { |
| name: "while/Switch_1" |
| op: "Switch" |
| input: "while/Merge_1" |
| input: "while/LoopCond" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "_class" |
| value { |
| list { |
| s: "loc:@while/Merge_1" |
| } |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 11 |
| } |
| } |
| attr { |
| key: "_output_slot_vector" |
| value { |
| list { |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 1 |
| i: 0 |
| } |
| } |
| } |
| } |
| node { |
| name: "while/Identity" |
| op: "Identity" |
| input: "while/Switch:1" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/Identity_1" |
| op: "Identity" |
| input: "while/Switch_1:1" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/add/y" |
| op: "Const" |
| input: "^while/Identity" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_INT32 |
| tensor_shape { |
| } |
| int_val: 1 |
| } |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/add" |
| op: "Add" |
| input: "while/Identity" |
| input: "while/add/y" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/concat/axis" |
| op: "Const" |
| input: "^while/Identity" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_INT32 |
| tensor_shape { |
| } |
| int_val: 0 |
| } |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/concat" |
| op: "ConcatV2" |
| input: "while/Identity_1" |
| input: "while/Identity_1" |
| input: "while/concat/axis" |
| attr { |
| key: "N" |
| value { |
| i: 2 |
| } |
| } |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "Tidx" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/NextIteration" |
| op: "NextIteration" |
| input: "while/add" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/NextIteration_1" |
| op: "NextIteration" |
| input: "while/concat" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 10 |
| } |
| } |
| } |
| node { |
| name: "while/Exit" |
| op: "Exit" |
| input: "while/Switch" |
| attr { |
| key: "T" |
| value { |
| type: DT_INT32 |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 1 |
| } |
| } |
| } |
| node { |
| name: "while/Exit_1" |
| op: "Exit" |
| input: "while/Switch_1" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "_execution_count" |
| value { |
| i: 1 |
| } |
| } |
| } |
| versions { |
| producer: 21 |
| } |
| )EOF"; |
| |
| grappler_item_.reset(new GrapplerItem); |
| CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, |
| &grappler_item_->graph)); |
| grappler_item_->id = "test_graph"; |
| grappler_item_->fetch = {"while/Exit", "while/Exit_1"}; |
| } |
| |
| // A simple condition graph. |
| void CreateGrapplerItemWithCondition() { |
| // Handcrafted test graph: a/Less -> Switch -> First/Second -> Merge. |
| const string gdef_ascii = R"EOF( |
| node { |
| name: "a" |
| op: "Const" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_FLOAT |
| tensor_shape { |
| } |
| float_val: 2.0 |
| } |
| } |
| } |
| } |
| node { |
| name: "Less" |
| op: "Const" |
| attr { |
| key: "dtype" |
| value { |
| type: DT_BOOL |
| } |
| } |
| attr { |
| key: "value" |
| value { |
| tensor { |
| dtype: DT_BOOL |
| tensor_shape { |
| } |
| tensor_content: "\001" |
| } |
| } |
| } |
| } |
| node { |
| name: "Switch" |
| op: "Switch" |
| input: "a" |
| input: "Less" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| } |
| node { |
| name: "First" |
| op: "Identity" |
| input: "Switch" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| } |
| node { |
| name: "Second" |
| op: "Identity" |
| input: "Switch:1" |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| } |
| node { |
| name: "Merge" |
| op: "Merge" |
| input: "First" |
| input: "Second" |
| attr { |
| key: "N" |
| value { |
| i: 2 |
| } |
| } |
| attr { |
| key: "T" |
| value { |
| type: DT_FLOAT |
| } |
| } |
| } |
| versions { |
| producer: 27 |
| })EOF"; |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, |
| &grappler_item_->graph)); |
| grappler_item_->id = "test_graph"; |
| grappler_item_->fetch = {"Merge"}; |
| } |
| |
| // Create a FusedBatchNorm op that has multiple output ports. |
| void CreateGrapplerItemWithInterDeviceTransfers() { |
| tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); |
| |
| // Create a FusedBatchNorm op that has multiple output ports. |
| auto x = ops::RandomUniform( |
| s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); |
| auto scale = |
| ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT); |
| auto offset = |
| ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT); |
| auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT); |
| auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT); |
| |
| auto batch_norm = ops::FusedBatchNorm( |
| s.WithOpName("bn"), x, scale, offset, mean, var, |
| ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f)); |
| auto y = batch_norm.y; |
| auto batch_mean = batch_norm.batch_mean; |
| auto batch_var = batch_norm.batch_variance; |
| // y1 and y2 take the same tensor, so there should be only 1 Send and Recv. |
| auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y); |
| auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y); |
| // batch_mean1 and batch_var1 take different output ports, so each will |
| // initiate Send/Recv. |
| auto batch_mean1 = ops::Identity( |
| s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean); |
| auto batch_var1 = |
| ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var); |
| // This is control dependency. |
| auto control_dep = ops::NoOp(s.WithOpName("control_dep") |
| .WithControlDependencies(y) |
| .WithDevice(kCPU1)); |
| |
| grappler_item_ = absl::make_unique<GrapplerItem>(); |
| TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph)); |
| grappler_item_->id = "test_conv2d_graph"; |
| grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1", |
| "control_dep"}; |
| |
| dependency_["bn"] = {"x", "mean", "var"}; |
| dependency_["y1"] = {"bn"}; |
| dependency_["y2"] = {"bn"}; |
| dependency_["batch_mean1"] = {"bn"}; |
| dependency_["batch_var1"] = {"bn"}; |
| dependency_["control_dep"] = {"bn"}; |
| } |
| |
| // Call this after creating grappler_item_ and setting up dependency_. |
| void InitScheduler() { TF_ASSERT_OK(scheduler_->Init(grappler_item_.get())); } |
| |
| // Returns cost based on op. |
| Costs SimplePredictCosts(const OpContext& op_context) const { |
| Costs c; |
| int64 exec_cost = 0; |
| if (op_context.op_info.op() == "MatMul") { |
| exec_cost = 2000000000; |
| } else if (op_context.op_info.op() == "RandomUniform") { |
| exec_cost = 1000000000; |
| } else { |
| exec_cost = 1000; |
| } |
| c.execution_time = Costs::NanoSeconds(exec_cost); |
| return c; |
| } |
| |
| // Call this after init scheduler_. Scheduler stops after executing |
| // target_node. |
| std::unordered_map<string, OpContext> RunScheduler( |
| const string& target_node) { |
| std::unordered_map<string, OpContext> ops_executed; |
| bool more_nodes = true; |
| do { |
| OpContext op_context = scheduler_->GetCurrNode(); |
| ops_executed[op_context.name] = op_context; |
| std::cout << op_context.name << std::endl; |
| |
| Costs node_costs = SimplePredictCosts(op_context); |
| |
| // Check scheduling order. |
| auto it = dependency_.find(op_context.name); |
| if (it != dependency_.end()) { |
| for (const auto& preceding_node : it->second) { |
| EXPECT_GT(ops_executed.count(preceding_node), 0); |
| } |
| } |
| more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs); |
| |
| if (op_context.name == target_node) { |
| // Scheduler has the state after executing the target node. |
| break; |
| } |
| } while (more_nodes); |
| return ops_executed; |
| } |
| |
// Expects that every element of test_elements also occurs in expected and
// that both vectors have the same size. Element order is not compared.
template <typename T>
void ExpectVectorEq(const std::vector<T>& expected,
                    const std::vector<T>& test_elements) {
  // A set makes the membership check straightforward.
  const std::set<T> allowed(expected.begin(), expected.end());
  for (auto it = test_elements.begin(); it != test_elements.end(); ++it) {
    EXPECT_GT(allowed.count(*it), 0);
  }
  EXPECT_EQ(expected.size(), test_elements.size());
}
| |
| // Helper method that checks the name of nodes. |
| void ValidateNodeDefs(const std::vector<string>& expected, |
| const std::vector<const NodeDef*>& node_defs) { |
| std::vector<string> node_names; |
| std::transform(node_defs.begin(), node_defs.end(), |
| std::back_inserter(node_names), |
| [](const NodeDef* node) { return node->name(); }); |
| ExpectVectorEq(expected, node_names); |
| } |
| |
// Expects that every element of test_elements is contained in expected and
// that both sets have the same size.
template <typename T>
void ExpectSetEq(const std::set<T>& expected,
                 const std::set<T>& test_elements) {
  for (auto it = test_elements.begin(); it != test_elements.end(); ++it) {
    EXPECT_GT(expected.count(*it), 0);
  }
  EXPECT_EQ(expected.size(), test_elements.size());
}
| |
// Expects that test_map has the same size as expected and that every key of
// expected is present in test_map with an equal value.
//
// A missing key is reported as a (non-fatal) expectation failure and its value
// comparison is skipped; the map is looked up with find() rather than at(), so
// a missing key can no longer throw out of the helper.
template <typename T, typename U>
void ExpectUnorderedMapEq(const std::unordered_map<T, U>& expected,
                          const std::unordered_map<T, U>& test_map) {
  EXPECT_EQ(expected.size(), test_map.size());
  for (const auto& key_val : expected) {
    EXPECT_GT(test_map.count(key_val.first), 0);
    const auto found = test_map.find(key_val.first);
    if (found == test_map.end()) {
      // Already reported above; there is no value to compare.
      continue;
    }
    EXPECT_EQ(found->second, key_val.second);
  }
}
| |
| // Helper method that checks name - port pairs. |
| void ValidateMemoryUsageSnapshot( |
| const std::vector<string>& expected_names, const int port_num_expected, |
| const std::unordered_set<std::pair<const NodeDef*, int>, |
| DeviceState::NodePairHash>& mem_usage_snapshot) { |
| std::set<std::pair<string, int>> nodes_at_peak_mem_usage; |
| std::transform( |
| mem_usage_snapshot.begin(), mem_usage_snapshot.end(), |
| std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()), |
| [](const std::pair<const NodeDef*, int>& node_port) { |
| return std::make_pair(node_port.first->name(), node_port.second); |
| }); |
| std::set<std::pair<string, int>> expected; |
| std::transform(expected_names.begin(), expected_names.end(), |
| std::inserter(expected, expected.begin()), |
| [port_num_expected](const string& name) { |
| return std::make_pair(name, port_num_expected); |
| }); |
| ExpectSetEq(expected, nodes_at_peak_mem_usage); |
| } |
| |
| // Helper method for checking nodes dependency. |
| void ValidateDependencyChain( |
| const std::unordered_map<string, int64>& start_times, |
| const std::vector<string>& nodes_in_dependency_order) { |
| int64 prev_node_time = -1; |
| for (const auto& node : nodes_in_dependency_order) { |
| int64 curr_node_time = start_times.at(node); |
| EXPECT_GE(curr_node_time, prev_node_time); |
| prev_node_time = curr_node_time; |
| } |
| } |
| |
| // cluster_ and scheduler_ are initialized in the c'tor. |
| std::unique_ptr<VirtualCluster> cluster_; |
| std::unique_ptr<TestVirtualScheduler> scheduler_; |
| FirstReadyManager first_ready_manager_; |
| CompositeNodeManager composite_node_manager_; |
| |
| // grappler_item_ will be initialized differently for each test case. |
| std::unique_ptr<GrapplerItem> grappler_item_; |
| // Node name -> its preceding nodes map for testing scheduling order. |
| std::unordered_map<string, std::vector<string>> dependency_; |
| |
| // Shared params for Conv2D related graphs: |
| const int batch_size_ = 4; |
| const int width_ = 10; |
| const int height_ = 10; |
| const int depth_in_ = 8; |
| const int kernel_ = 3; |
| const int depth_out_ = 16; |
| }; |
| |
| // Create small graph, run predict costs on it, make sure the costs from the |
| // summary match the hand-calculated costs. |
| TEST_F(VirtualSchedulerTest, SummaryCostTest) { |
| // Run matmul test. |
| CreateGrapplerItemWithMatmulChain(); |
| InitScheduler(); |
| auto ops_executed = RunScheduler(""); |
| Costs c = scheduler_->Summary(); |
| |
| // RandomUniform - 5 * 1s |
| // Matmuls - 4 * 2s = 8 |
| // Misc - 5 * 1us |
| // Total: 13000005 |
| EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count()); |
| EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total); |
| EXPECT_FALSE(c.inaccurate); |
| EXPECT_EQ(0, c.num_ops_with_unknown_shapes); |
| } |
| |
| // Like the above SummaryCostTest, but makes sure the stepstats timeline is |
| // correct. |
| TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) { |
| // Run matmul test. |
| CreateGrapplerItemWithMatmulChain(); |
| InitScheduler(); |
| auto ops_executed = RunScheduler(""); |
| RunMetadata metadata; |
| Costs c = scheduler_->Summary(&metadata); |
| StepStats stepstats = metadata.step_stats(); |
| EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count()); |
| EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total); |
| EXPECT_FALSE(c.inaccurate); |
| EXPECT_EQ(0, c.num_ops_with_unknown_shapes); |
| |
| // Should only be 1 device! |
| EXPECT_EQ(1, stepstats.dev_stats().size()); |
| |
| // Create a map of op name -> start and end times (micros). |
| std::map<string, std::pair<int64, int64>> start_end_times; |
| for (const auto& device_step_stats : stepstats.dev_stats()) { |
| for (const auto& stats : device_step_stats.node_stats()) { |
| int64 start = stats.all_start_micros(); |
| int64 end = start + stats.all_end_rel_micros(); |
| start_end_times[stats.node_name()] = std::pair<int64, int64>(start, end); |
| |
| // Make sure that the output properties are correct for |
| // MatMul and RandomUniform operations. |
| // We only check for dtype, and shape (excluding alloc) |
| // since alloc is not set by the virtual scheduler. |
| if (stats.timeline_label() == "MatMul" || |
| stats.timeline_label() == "RandomUniform") { |
| EXPECT_EQ(1, stats.output().size()); |
| for (const auto& output : stats.output()) { |
| EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype()); |
| EXPECT_EQ(2, output.tensor_description().shape().dim().size()); |
| for (const auto& dim : output.tensor_description().shape().dim()) { |
| EXPECT_EQ(3200, dim.size()); |
| } |
| } |
| } |
| } |
| } |
| |
| // The base start_time is the time to compute RandomUniforms |
| int64 cur_time = static_cast<int64>(5000005); |
| // The increment is the execution time of one matmul. See |
| // CreateGrapplerItemWithMatmulChain for details. |
| int64 increment = static_cast<int64>(2000000); |
| auto op_names = {"ab", "abc", "abcd", "abcde"}; |
| for (const auto& op_name : op_names) { |
| int64 actual_start = start_end_times[op_name].first; |
| int64 actual_end = start_end_times[op_name].second; |
| int64 expected_start = cur_time; |
| int64 expected_end = cur_time + increment; |
| EXPECT_EQ(expected_start, actual_start); |
| EXPECT_EQ(expected_end, actual_end); |
| cur_time += increment; |
| } |
| } |
| |
| TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) { |
| // Init. |
| CreateGrapplerItemWithConv2Ds(); |
| InitScheduler(); |
| |
| // Run the scheduler. |
| auto ops_executed = RunScheduler(""); // Run all the nodes. |
| |
| // [const and rand] * (x, y, f), and c0 and c1. c2 and z shouldn't be |
| // executed. |
| EXPECT_EQ(8, ops_executed.size()); |
| |
| // x, y, f, c0, and c1 should be in the ops executed. |
| EXPECT_GT(ops_executed.count("x"), 0); |
| EXPECT_GT(ops_executed.count("y"), 0); |
| EXPECT_GT(ops_executed.count("f"), 0); |
| EXPECT_GT(ops_executed.count("c0"), 0); |
| EXPECT_GT(ops_executed.count("c1"), 0); |
| |
| // z and c2 shouldn't be part of it. |
| EXPECT_EQ(ops_executed.count("z"), 0); |
| EXPECT_EQ(ops_executed.count("c2"), 0); |
| |
| // Check input / output properties. |
| EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size()); |
| EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size()); |
| EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size()); |
| EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size()); |
| EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size()); |
| } |
| |
| TEST_F(VirtualSchedulerTest, MemoryUsage) { |
| // Init. |
| CreateGrapplerItemWithAddN(); |
| InitScheduler(); |
| |
| // Run the scheduler. |
| RunScheduler(""); |
| |
| const auto* device_states = scheduler_->GetDeviceStates(); |
| const auto& cpu_state = device_states->at(kCPU0); |
| |
| // out node adds 4 tensors, each with 10x10x10x10, so the peak memory usage |
| // is 4 x the input tensor size while executing the out node. |
| int64 one_input_node_size = 4 * 10 * 10 * 10 * 10; |
| const std::vector<string> expected_names = {"x", "y", "z", "w"}; |
| EXPECT_EQ(expected_names.size() * one_input_node_size, |
| cpu_state.max_memory_usage); |
| ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */, |
| cpu_state.mem_usage_snapshot_at_peak); |
| ExpectUnorderedMapEq( |
| {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 64)}, |
| scheduler_->GetPersistentMemoryUsage()); |
| ExpectUnorderedMapEq( |
| {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 160000)}, |
| scheduler_->GetPeakMemoryUsage()); |
| } |
| |
| TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) { |
| CreateGrapplerItemWithUnnecessaryPlaceholderNodes(); |
| InitScheduler(); |
| |
| // Test that scheduler can run graphs with extra unnecessary feed nodes. |
| auto ops_executed = RunScheduler(""); |
| ASSERT_EQ(1, ops_executed.size()); |
| ASSERT_EQ(ops_executed.count("x"), 1); |
| } |
| |
| TEST_F(VirtualSchedulerTest, ControlDependency) { |
| // Init. |
| CreateGrapplerItemWithControlDependency(); |
| InitScheduler(); |
| |
| // Run the scheduler. |
| RunScheduler(""); |
| |
| const auto* device_states = scheduler_->GetDeviceStates(); |
| const auto& cpu_state = device_states->at(kCPU0); |
| |
| // The graph has a NoOp that takes control dependency from 7 NoOps. The peak |
| // memory usage is when executing the final NoOp. |
| int64 one_input_node_size = 4; // control dependency |
| const std::vector<string> expected_names = {"x", "y", "z", "w", |
| "u", "v", "t"}; |
| EXPECT_EQ(expected_names.size() * one_input_node_size, |
| cpu_state.max_memory_usage); |
| ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */, |
| cpu_state.mem_usage_snapshot_at_peak); |
| ExpectUnorderedMapEq( |
| {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 0)}, |
| scheduler_->GetPersistentMemoryUsage()); |
| ExpectUnorderedMapEq( |
| {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 28)}, |
| scheduler_->GetPeakMemoryUsage()); |
| } |
| |
| TEST_F(VirtualSchedulerTest, ComplexDependency) { |
| // Init. |
| CreateGrapplerItemWithBatchNorm(); |
| InitScheduler(); |
| |
| // Run the scheduler. |
| RunScheduler("bn"); |
| |
| const auto& device_states = scheduler_->GetDeviceStates(); |
| const auto& cpu_state = device_states->at(kCPU0); |
| |
| // The graph is |
| // bn = FusedBatchNorm(x, scale, offset, mean, var) |
| // z1 = bn.y + x |
| // z2 = bn.var + bn.var |
| // z3 = bn.var + bn.var |
| // z4 = control dependency from bn. |
| // Note that bn.mean doesn't have any consumer. |
| const int x_size = batch_size_ * width_ * height_ * depth_in_; |
| int64 expected_size = |
| 4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ + |
| 1 /* control dependency */); |
| EXPECT_EQ(expected_size, cpu_state.memory_usage); |
| |
| // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0. |
| std::set<std::pair<string, int>> nodes_in_memory; |
| std::transform( |
| cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(), |
| std::inserter(nodes_in_memory, nodes_in_memory.begin()), |
| [](const std::pair<const NodeDef*, int>& node_port) { |
| return std::make_pair(node_port.first->name(), node_port.second); |
| }); |
| std::set<std::pair<string, int>> expected = { |
| std::make_pair("bn", -1), |
| std::make_pair("bn", 0), |
| std::make_pair("bn", 2), |
| std::make_pair("x", 0), |
| }; |
| ExpectSetEq(expected, nodes_in_memory); |
| |
| const auto* node_states = scheduler_->GetNodeStates(); |
| const NodeState* bn_node = nullptr; |
| const NodeState* x_node = nullptr; |
| for (const auto& nodedef_node_state : *node_states) { |
| const NodeDef* node = nodedef_node_state.first; |
| const NodeState& node_state = nodedef_node_state.second; |
| if (node->name() == "bn") { |
| bn_node = &node_state; |
| } |
| if (node->name() == "x") { |
| x_node = &node_state; |
| } |
| } |
| CHECK_NOTNULL(bn_node); |
| CHECK_NOTNULL(x_node); |
| |
| ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0)); |
| ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1)); |
| ValidateNodeDefs({"z1"}, bn_node->outputs.at(0)); |
| // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2. |
| ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2)); |
| } |
| |
| TEST_F(VirtualSchedulerTest, Variable) { |
| // Init. |
| CreateGrapplerItemWithConv2DAndVariable(); |
| InitScheduler(); |
| |
| // Run the scheduler. |
| RunScheduler(""); |
| |
| const auto* device_states = scheduler_->GetDeviceStates(); |
| const auto& cpu_state = device_states->at(kCPU0); |
| |
| // There is one Conv2D that takes x and f, but f is variable, so it should be |
| // in persistent nodes. |
| ValidateMemoryUsageSnapshot({"f", "Const/Const"}, /*port_num_expected=*/0, |
| cpu_state.persistent_nodes); |
| // Only x in peak memory usage snapshot. |
| ValidateMemoryUsageSnapshot({"x"}, /*port_num_expected=*/0, |
| cpu_state.mem_usage_snapshot_at_peak); |
| ExpectUnorderedMapEq( |
| {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 4624)}, |
| scheduler_->GetPersistentMemoryUsage()); |
| ExpectUnorderedMapEq( |
| {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 12800)}, |
| scheduler_->GetPeakMemoryUsage()); |
| } |
| |
| TEST_F(VirtualSchedulerTest, WhileLoop) { |
| // Init. |
| CreateGrapplerItemWithLoop(); |
| InitScheduler(); |
| |
| // Run the scheduler. |
| RunScheduler(""); |
| |
| // Check the timeline |
| RunMetadata metadata; |
| scheduler_->Summary(&metadata); |
| |
| // Nodes in topological order: |
| // * const, ones |
| // * while/Enter, while/Enter_1 |
| // * while/Merge, while/Merge_1 |
| // * while/Less/y |
| // * while/Less |
| // * while/LoopCond |
| // * while/Switch, while/Switch_1 |
| // * while/Identity, while/Identity_1, while/Exit, while/Exit_1 |
| // * while/add/y, while/concat/axis |
| // * while/add, while/concat |
| // * while/NextIteration, while/NextIteration_1 |
| |
| int num_next_iteration = 0; |
| int num_next_iteration_1 = 0; |
| int num_exit = 0; |
| int num_exit_1 = 0; |
| int64 next_iter_start_micro; |
| int64 next_iter_1_start_micro; |
| int64 exit_start_micro; |
| int64 exit_1_start_micro; |
| |
| std::unordered_map<string, int64> start_times; |
| for (const auto& device_step_stats : metadata.step_stats().dev_stats()) { |
| for (const auto& stats : device_step_stats.node_stats()) { |
| start_times[stats.node_name()] = stats.all_start_micros(); |
| if (stats.node_name() == "while/NextIteration") { |
| ++num_next_iteration; |
| next_iter_start_micro = stats.all_start_micros(); |
| } else if (stats.node_name() == "while/NextIteration_1") { |
| ++num_next_iteration_1; |
| next_iter_1_start_micro = stats.all_start_micros(); |
| } else if (stats.node_name() == "while/Exit") { |
| ++num_exit; |
| exit_start_micro = stats.all_start_micros(); |
| } else if (stats.node_name() == "while/Exit_1") { |
| ++num_exit_1; |
| exit_1_start_micro = stats.all_start_micros(); |
| } |
| } |
| } |
| |
| // Make sure we went though the body of the loop once, and that the output of |
| // the loop was scheduled as well. |
| EXPECT_EQ(1, num_next_iteration); |
| EXPECT_EQ(1, num_next_iteration_1); |
| EXPECT_EQ(1, num_exit); |
| EXPECT_EQ(1, num_exit_1); |
| |
| // Start times of while/NextIteration and while/NextIteration_1 should be |
| // different, so should be those of while/Exit and while/Exit_1. |
| EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro); |
| EXPECT_NE(exit_start_micro, exit_1_start_micro); |
| |
| // Check dependency among the nodes; no matter what scheduling mechanism we |
| // use, the scheduled ops should follow these dependency chains. |
| // Note that currently, VirtualScheduler executes while/Merge twice; hence, |
| // we're not testing dependency chains related to while/Merge. |
| // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the |
| // order of Enter, Merge, ...loop condition ..., ... loop body ..., |
| // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency |
| // chaining test w/ Merge nodes. |
| ValidateDependencyChain( |
| start_times, |
| {"Const", "while/Enter", // "while/Merge", |
| "while/Less/y", "while/Less", "while/LoopCond", "while/Switch", |
| "while/Identity", "while/add/y", "while/add", "while/NextIteration"}); |
| // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"}); |
| ValidateDependencyChain(start_times, |
| {"ones", "while/Enter_1", // "while/Merge_1", |
| "while/Switch_1", "while/Identity_1", "while/concat", |
| "while/NextIteration_1"}); |
| ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"}); |
| ValidateDependencyChain( |
| start_times, {"while/Identity", "while/concat/axis", "while/concat"}); |
| ValidateDependencyChain(start_times, {"while/Identity", "while/add"}); |
| ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"}); |
| } |
| |
| TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) { |
| { |
| // Init. |
| CreateGrapplerItemWithLoop(); |
| InitScheduler(); |
| |
| // Runs the scheduler. |
| RunScheduler(""); |
| Costs c = scheduler_->Summary(); |
| |
| EXPECT_EQ(23, c.execution_time.asMicroSeconds().count()); |
| // Both while/Merge and while/Merge_1 are scheduled twice. |
| EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total); |
| EXPECT_FALSE(c.inaccurate); |
| EXPECT_EQ(0, c.num_ops_with_unknown_shapes); |
| } |
| |
| { |
| // Init. |
| CreateGrapplerItemWithLoopAnnotated(); |
| InitScheduler(); |
| |
| // Runs the scheduler. |
| RunScheduler(""); |
| Costs c = scheduler_->Summary(); |
| |
| // The costs for Merge is accumulated twice for execution_count times, but |
| // since Merge's cost is minimal, we keep this behavior here. |
| EXPECT_EQ(178, c.execution_time.asMicroSeconds().count()); |
| // Both while/Merge and while/Merge_1 are scheduled twice. |
| EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total); |
| EXPECT_FALSE(c.inaccurate); |
| EXPECT_EQ(0, c.num_ops_with_unknown_shapes); |
| } |
| } |
| |
| TEST_F(VirtualSchedulerTest, Condition) { |
| // Without annotation. |
| { |
| // Inits. |
| CreateGrapplerItemWithCondition(); |
| InitScheduler(); |
| |
| // Runs the scheduler. |
| RunScheduler(""); |
| RunMetadata metadata; |
| Costs c = scheduler_->Summary(&metadata); |
| |
| // Nodes in topological order: a/Less, Switch, First/Second, Merge. |
| int num_a = 0; |
| int num_less = 0; |
| int num_switch = 0; |
| int num_first = 0; |
| int num_second = 0; |
| int num_merge = 0; |
| |
| for (const auto& device_step_stats : metadata.step_stats().dev_stats()) { |
| for (const auto& stats : device_step_stats.node_stats()) { |
| if (stats.node_name() == "a") { |
| ++num_a; |
| } else if (stats.node_name() == "Less") { |
| ++num_less; |
| } else if (stats.node_name() == "Switch") { |
| ++num_switch; |
| } else if (stats.node_name() == "First") { |
| ++num_first; |
| } else if (stats.node_name() == "Second") { |
| ++num_second; |
| } else if (stats.node_name() == "Merge") { |
| ++num_merge; |
| } |
| } |
| } |
| |
| EXPECT_EQ(1, num_a); |
| EXPECT_EQ(1, num_less); |
| EXPECT_EQ(1, num_switch); |
| EXPECT_EQ(1, num_first); |
| EXPECT_EQ(1, num_second); |
| EXPECT_EQ(2, num_merge); |
| |
| EXPECT_EQ(7, c.execution_time.asMicroSeconds().count()); |
| // Merge is executed twice. |
| EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total); |
| EXPECT_FALSE(c.inaccurate); |
| EXPECT_EQ(0, c.num_ops_with_unknown_shapes); |
| } |
| |
| // With annotation. |
| { |
| // Inits. |
| CreateGrapplerItemWithCondition(); |
| |
| // Annotates the Switch node. |
| for (auto& node : *grappler_item_->graph.mutable_node()) { |
| if (node.name() == "Switch") { |
| AttrValue attr_output_info; |
| // Adds one output slot 0 so that Second shouldn't be executed. |
| (*attr_output_info.mutable_list()).add_i(0); |
| AddNodeAttr(kOutputSlots, attr_output_info, &node); |
| } |
| } |
| |
| InitScheduler(); |
| |
| // Runs the scheduler. |
| RunScheduler(""); |
| RunMetadata metadata; |
| Costs c = scheduler_->Summary(&metadata); |
| |
| // Nodes in topological order: a/Less, Switch, Merge |
| int num_a = 0; |
| int num_less = 0; |
| int num_switch = 0; |
| int num_first = 0; |
| int num_second = 0; |
| int num_merge = 0; |
| |
| for (const auto& device_step_stats : metadata.step_stats().dev_stats()) { |
| for (const auto& stats : device_step_stats.node_stats()) { |
| if (stats.node_name() == "a") { |
| ++num_a; |
| } else if (stats.node_name() == "Less") { |
| ++num_less; |
| } else if (stats.node_name() == "Switch") { |
| ++num_switch; |
| } else if (stats.node_name() == "First") { |
| ++num_first; |
| } else if (stats.node_name() == "Second") { |
| ++num_second; |
| } else if (stats.node_name() == "Merge") { |
| ++num_merge; |
| } |
| } |
| } |
| |
| EXPECT_EQ(1, num_a); |
| EXPECT_EQ(1, num_less); |
| EXPECT_EQ(1, num_switch); |
| EXPECT_EQ(1, num_first); |
| EXPECT_EQ(0, num_second); |
| EXPECT_EQ(1, num_merge); |
| |
| EXPECT_EQ(5, c.execution_time.asMicroSeconds().count()); |
| // Second is not executed. |
| EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total); |
| EXPECT_FALSE(c.inaccurate); |
| EXPECT_EQ(0, c.num_ops_with_unknown_shapes); |
| } |
| } |
| |
| TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { |
| // Init. |
| CreateGrapplerItemWithInterDeviceTransfers(); |
| InitScheduler(); |
| |
| // Run the scheduler. |
| auto ops_executed = RunScheduler(""); |
| |
| // Helper lambda to extract port num from _Send and _Recv op name. |
| auto get_port_num = [](const string& name) -> int { |
| if (name.find("bn_0") != string::npos) { |
| return 0; |
| } else if (name.find("bn_1") != string::npos) { |
| return 1; |
| } else if (name.find("bn_2") != string::npos) { |
| return 2; |
| } else if (name.find("bn_minus1") != string::npos) { |
| return -1; |
| } |
| return -999; |
| }; |
| |
| // Reorganize ops_executed for further testing. |
| std::unordered_map<string, int> op_count; |
| std::unordered_map<int, string> recv_op_names; |
| std::unordered_map<int, string> send_op_names; |
| for (const auto& x : ops_executed) { |
| const auto& name = x.first; |
| const auto& node_info = x.second; |
| const auto& op = node_info.op_info.op(); |
| if (op == kRecv) { |
| recv_op_names[get_port_num(name)] = name; |
| } else if (op == kSend) { |
| send_op_names[get_port_num(name)] = name; |
| } |
| op_count[op]++; |
| } |
| |
| // Same number of _Send and _Recv. |
| EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv)); |
| |
| // Expect 4 Send and Recvs each: port 0, 1, and, 2, and control dependency. |
| EXPECT_EQ(op_count.at(kRecv), 4); |
| EXPECT_EQ(op_count.at(kSend), 4); |
| |
| // Helper lambda for extracting output Tensor size. |
| auto get_output_size = [this, ops_executed](const string& name) -> int64 { |
| const auto& output_properties_ = ops_executed.at(name).op_info.outputs(); |
| std::vector<OpInfo::TensorProperties> output_properties; |
| for (const auto& output_property : output_properties_) { |
| output_properties.push_back(output_property); |
| } |
| return CalculateOutputSize(output_properties, 0); |
| }; |
| |
| // Validate transfer size. |
| // Batchnorm output y is 4D vector: batch x width x width x depth. |
| int input_size = 4 * batch_size_ * width_ * height_ * depth_in_; |
| EXPECT_EQ(get_output_size(recv_op_names[0]), input_size); |
| EXPECT_EQ(get_output_size(send_op_names[0]), input_size); |
| // Mean and vars are 1-D vector with size depth_in_. |
| EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_); |
| EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_); |
| EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_); |
| EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_); |
| // Control dependency size is 4B. |
| EXPECT_EQ(get_output_size(recv_op_names[-1]), 4); |
| EXPECT_EQ(get_output_size(send_op_names[-1]), 4); |
| } |
| |
| TEST_F(VirtualSchedulerTest, GraphWithSendRecv) { |
| // Init. |
| CreateGrapplerItemWithSendRecv(); |
| InitScheduler(); |
| |
| // Run the scheduler. |
| auto ops_executed = RunScheduler(""); |
| |
| EXPECT_GT(ops_executed.count("Const"), 0); |
| EXPECT_GT(ops_executed.count("Send"), 0); |
| EXPECT_GT(ops_executed.count("Recv"), 0); |
| } |
| |
| TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) { |
| // Init. |
| CreateGrapplerItemWithSendRecv(); |
| // Change Recv node's device so that Send and Recv are placed on different |
| // devices. |
| auto& graph = grappler_item_->graph; |
| const string recv_device = kCPU1; |
| for (int i = 0; i < graph.node_size(); i++) { |
| auto* node = graph.mutable_node(i); |
| if (node->name() == "Recv") { |
| node->set_device(recv_device); |
| auto* attr = node->mutable_attr(); |
| (*attr)["recv_device"].set_s(recv_device); |
| } else if (node->name() == "Send") { |
| auto* attr = node->mutable_attr(); |
| (*attr)["recv_device"].set_s(recv_device); |
| } |
| } |
| InitScheduler(); |
| |
| // Run the scheduler. |
| auto ops_executed = RunScheduler(""); |
| |
| // Expect Const, Send, Recv, and VirtualScheduler created Send and Recv ops. |
| EXPECT_GT(ops_executed.count("Const"), 0); |
| EXPECT_GT(ops_executed.count("Send"), 0); |
| EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/" |
| "task_0/cpu_0_to_/job_localhost" |
| "/replica_0/task_0/cpu_1"), |
| 0); |
| EXPECT_GT(ops_executed.count( |
| "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"), |
| 0); |
| EXPECT_GT(ops_executed.count("Recv"), 0); |
| } |
| |
| TEST_F(VirtualSchedulerTest, GraphWihtOnlyRecv) { |
| // Init. |
| CreateGrapplerItemWithRecvWithoutSend(); |
| InitScheduler(); |
| |
| // Run the scheduler. |
| auto ops_executed = RunScheduler(""); |
| |
| // Recv without Send will be treated as initially ready node. |
| EXPECT_GT(ops_executed.count("Recv"), 0); |
| } |
| |
| TEST_F(VirtualSchedulerTest, AddMergeSwitch) { |
| // Override scheduler_ with CompositeNodeManager. |
| scheduler_ = absl::make_unique<TestVirtualScheduler>( |
| /*use_static_shapes=*/true, |
| /*use_aggressive_shape_inference=*/true, &composite_node_manager_, |
| cluster_.get()); |
| CreateGrapplerItemWithSwitchMergeInput(); |
| InitScheduler(); |
| |
| // pred --+ z --+ |
| // | | |
| // V V |
| // x -> Switch --------> Merge ---> Add --> y |
| // | ^ |
| // | | |
| // +-----> Add -----+ |
| // ^ |
| // | |
| // b --------------+ |
| |
| // Run the scheduler. The current VirtualScheduler, w/o annotation, triggers |
| // both outputs of Switch; then Merge (as long as one input is ready, it's z |
| // is ready, if we just use num_inputs_ready counter, the final Add becomes |
| // ready. possible to skip scheduling z. (Need to use CompositeNodeManager |
| // to test this case). |
| auto ops_executed = RunScheduler(""); |
| |
| EXPECT_GT(ops_executed.count("z"), 0); |
| } |
| |
| TEST_F(VirtualSchedulerTest, AddFromOneTensor) { |
| CreateGrapplerItemWithAddFromOneTensor(); |
| InitScheduler(); |
| |
| // x -+----> Add --> y |
| // | ^ |
| // | | |
| // +-------+ |
| |
| // Run the scheduler. |
| auto ops_executed = RunScheduler(""); |
| EXPECT_GT(ops_executed.count("y"), 0); |
| EXPECT_GT(ops_executed.count("x"), 0); |
| } |
| |
| } // namespace |
| } // end namespace grappler |
| } // end namespace tensorflow |