| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #include "tensorflow/cc/ops/const_op.h" |
| #include "tensorflow/cc/ops/sendrecv_ops.h" |
| #include "tensorflow/cc/ops/standard_ops.h" |
| #include "tensorflow/core/framework/tensor_testutil.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/lib/io/path.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/platform/test.h" |
| #include "tensorflow/core/platform/test_benchmark.h" |
| #include "tensorflow/core/public/session.h" |
| #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" |
| #include "tensorflow/tools/graph_transforms/transform_utils.h" |
| |
| namespace tensorflow { |
| namespace graph_transforms { |
| |
| // Declarations so we don't need a public header. |
| Status SparsifyGather(const GraphDef& input_graph_def, |
| const TransformFuncContext& context, |
| GraphDef* output_graph_def); |
| Status ReadTensorFromCheckpoint( |
| const string& tensor_name, const std::unique_ptr<BundleReader>& ckpt_reader, |
| const string& shape_and_slice, Tensor* tensor); |
| |
| class SparsifyGatherTest : public ::testing::Test { |
| protected: |
| NodeDef* CreateNode(const StringPiece name, const StringPiece op, |
| const std::vector<NodeDef*>& inputs, GraphDef* graph_def, |
| bool control_dep = false) { |
| NodeDef* node_def = graph_def->add_node(); |
| node_def->set_name(string(name)); |
| node_def->set_op(string(op)); |
| if (!control_dep) { |
| std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) { |
| node_def->add_input(input->name()); |
| }); |
| } else { |
| std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) { |
| node_def->add_input(strings::StrCat("^", input->name())); |
| }); |
| } |
| return node_def; |
| } |
| |
| void MakeGather(StringPiece name, bool gather_v2, NodeDef* params, |
| NodeDef* indices, GraphDef* graph_def) { |
| if (gather_v2) { |
| NodeDef* axis_node = |
| CreateNode(strings::StrCat(name, "_axis"), "Const", {}, graph_def); |
| Tensor axis_t(DT_INT32, TensorShape({})); |
| axis_t.scalar<int32>()() = 0; |
| SetNodeTensorAttr<int32>("value", axis_t, axis_node); |
| CreateNode(name, "GatherV2", {params, indices, axis_node}, graph_def); |
| } else { |
| CreateNode(name, "Gather", {params, indices}, graph_def); |
| } |
| } |
| |
| void TestSinglePartition(bool gather_v2, bool include_shared_init, |
| bool test_variable, bool test_kept_concat, |
| const string& shared_init_name = "group_deps") { |
| GraphDef graph_def; |
| |
| const auto checkpoint_path = |
| io::JoinPath(testing::TmpDir(), "checkpoint_single"); |
| // Build the graph. |
| NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def); |
| NodeDef* w_node; |
| NodeDef* zeros_const; |
| NodeDef* zeros_shape; |
| NodeDef* zeros_node; |
| NodeDef* assign_node; |
| |
| Tensor weights(DT_FLOAT, TensorShape({4, 1})); |
| test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001}); |
| |
| if (!test_variable) { |
| w_node = CreateNode("w/part_1", "Const", {}, &graph_def); |
| SetNodeTensorAttr<float>("value", weights, w_node); |
| } else { |
| w_node = CreateNode("w/part_1", "VariableV2", {}, &graph_def); |
| |
| zeros_shape = CreateNode("w/part_1/Initializer/zeros/shape_as_tensor", |
| "Const", {}, &graph_def); |
| zeros_const = CreateNode("w/part_1/Initializer/zeros/Const", "Const", {}, |
| &graph_def); |
| zeros_node = CreateNode("w/part_1/Initializer/zeros", "Fill", |
| {zeros_shape, zeros_const}, &graph_def); |
| assign_node = CreateNode("w/part_1/Assign", "Assign", |
| {w_node, zeros_node}, &graph_def); |
| |
| NodeDef* save_const_node = |
| CreateNode("save/Const", "Const", {}, &graph_def); |
| |
| Tensor tensor_names_values(DT_STRING, TensorShape({1})); |
| test::FillValues<string>(&tensor_names_values, {"w"}); |
| NodeDef* tensor_names_node = |
| CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def); |
| SetNodeTensorAttr<string>("value", tensor_names_values, |
| tensor_names_node); |
| |
| NodeDef* tensor_shapes_slices_node = CreateNode( |
| "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); |
| Tensor shapes_slices_val(DT_STRING, TensorShape({1})); |
| shapes_slices_val.flat<tstring>()(0) = "4 1 0,4:0,1"; |
| SetNodeTensorAttr<string>("value", shapes_slices_val, |
| tensor_shapes_slices_node); |
| |
| NodeDef* restore_node = CreateNode( |
| "save/RestoreV2", "RestoreV2", |
| {save_const_node, tensor_names_node, tensor_shapes_slices_node}, |
| &graph_def); |
| CreateNode("save/Assign", "Assign", {w_node, restore_node}, &graph_def); |
| |
| BundleWriter writer(Env::Default(), checkpoint_path); |
| TF_ASSERT_OK(writer.Add("w", weights)); |
| TF_ASSERT_OK(writer.Finish()); |
| } |
| SetNodeAttr("dtype", DT_FLOAT, w_node); |
| |
| NodeDef* identity_node = |
| CreateNode("w/read", "Identity", {w_node}, &graph_def); |
| MakeGather("gather", gather_v2, identity_node, input_node, &graph_def); |
| if (include_shared_init) { |
| if (!test_variable) { |
| CreateNode(shared_init_name, "NoOp", {}, &graph_def); |
| } else { |
| CreateNode(shared_init_name, "NoOp", {assign_node}, &graph_def, true); |
| } |
| } |
| |
| NodeDef* concat_axis_node = |
| CreateNode("linear/concat/axis", "Const", {}, &graph_def); |
| NodeDef* concat_input_node = |
| CreateNode("concat/input/node", "Const", {}, &graph_def); |
| NodeDef* concat_node = nullptr; |
| if (!test_kept_concat) { |
| concat_node = CreateNode( |
| "concat/node", "ConcatV2", |
| {identity_node, concat_input_node, concat_axis_node}, &graph_def); |
| SetNodeAttr("N", 2, concat_node); |
| } else { |
| NodeDef* concat_input_node_2 = |
| CreateNode("concat/input/node_2", "Const", {}, &graph_def); |
| concat_node = CreateNode("concat/node", "ConcatV2", |
| {identity_node, concat_input_node, |
| concat_input_node_2, concat_axis_node}, |
| &graph_def); |
| SetNodeAttr("N", 3, concat_node); |
| } |
| |
| // Run the op. |
| GraphDef result; |
| TransformFuncContext context; |
| context.input_names = {"ids"}; |
| context.output_names = {"gather"}; |
| if (test_variable) { |
| context.params["input_checkpoint"] = {checkpoint_path}; |
| } |
| if (shared_init_name != "group_deps") { |
| context.params["group_init_node"] = {shared_init_name}; |
| } |
| TF_ASSERT_OK(SparsifyGather(graph_def, context, &result)); |
| |
| // Validation begins. |
| std::map<string, const NodeDef*> node_lookup; |
| MapNamesToNodes(result, &node_lookup); |
| |
| // Check nodes. |
| EXPECT_EQ(0, |
| node_lookup.count("w/part_1/Initializer/zeros/shape_as_tensor")); |
| EXPECT_EQ(0, node_lookup.count("w/part_1/Initializer/zeros/Const")); |
| EXPECT_EQ(0, node_lookup.count("w/part_1/Initializer/zeros")); |
| EXPECT_EQ(0, node_lookup.count("w/part_1/Assign")); |
| |
| EXPECT_EQ(1, node_lookup.count("ids")); |
| EXPECT_EQ("Const", node_lookup.at("ids")->op()); |
| |
| EXPECT_EQ(1, node_lookup.count("concat/node")); |
| |
| if (!test_kept_concat) { |
| EXPECT_EQ(0, node_lookup.count("linear/concat/axis")); |
| EXPECT_EQ("Identity", node_lookup.at("concat/node")->op()); |
| EXPECT_EQ(1, node_lookup.at("concat/node")->input_size()); |
| EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0)); |
| } else { |
| EXPECT_EQ(1, node_lookup.count("linear/concat/axis")); |
| EXPECT_EQ("ConcatV2", node_lookup.at("concat/node")->op()); |
| EXPECT_EQ(3, node_lookup.at("concat/node")->input_size()); |
| EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0)); |
| EXPECT_EQ("concat/input/node_2", node_lookup.at("concat/node")->input(1)); |
| EXPECT_EQ("linear/concat/axis", node_lookup.at("concat/node")->input(2)); |
| EXPECT_EQ(2, node_lookup.at("concat/node")->attr().at("N").i()); |
| } |
| |
| EXPECT_EQ(1, node_lookup.count("w/part_1/indices")); |
| EXPECT_EQ("Const", node_lookup.at("w/part_1/indices")->op()); |
| Tensor expected_indices_tensor(DT_INT64, TensorShape({3})); |
| test::FillValues<int64>(&expected_indices_tensor, {0, 2, 3}); |
| test::ExpectTensorEqual<int64>( |
| expected_indices_tensor, |
| GetNodeTensorAttr(*(node_lookup.at("w/part_1/indices")), "value")); |
| |
| EXPECT_EQ(1, node_lookup.count("w/part_1/values")); |
| EXPECT_EQ("Const", node_lookup.at("w/part_1/values")->op()); |
| Tensor expected_values_tensor(DT_FLOAT, TensorShape({3})); |
| test::FillValues<float>(&expected_values_tensor, {0.2, 1.2, 0.001}); |
| test::ExpectTensorNear<float>( |
| expected_values_tensor, |
| GetNodeTensorAttr(*(node_lookup.at("w/part_1/values")), "value"), 1e-5); |
| |
| EXPECT_EQ(1, node_lookup.count("w/part_1/HashTable")); |
| EXPECT_EQ("HashTable", node_lookup.at("w/part_1/HashTable")->op()); |
| |
| EXPECT_EQ(1, node_lookup.count("w/part_1/InitializeTable")); |
| EXPECT_EQ("InitializeTable", |
| node_lookup.at("w/part_1/InitializeTable")->op()); |
| |
| // Nodes in "gather" scope. |
| EXPECT_EQ(1, node_lookup.count("gather/LookupTableFind")); |
| EXPECT_EQ("LookupTableFind", |
| node_lookup.at("gather/LookupTableFind")->op()); |
| |
| EXPECT_EQ(1, node_lookup.count("gather/Const")); |
| EXPECT_EQ("Const", node_lookup.at("gather/Const")->op()); |
| Tensor expected_gather_default_tensor(DT_FLOAT, TensorShape({})); |
| test::FillValues<float>(&expected_gather_default_tensor, {0.0}); |
| test::ExpectTensorNear<float>( |
| expected_gather_default_tensor, |
| GetNodeTensorAttr(*(node_lookup.at("gather/Const")), "value"), 1e-5); |
| |
| EXPECT_EQ(1, node_lookup.count("gather/ExpandDims/Const")); |
| EXPECT_EQ("Const", node_lookup.at("gather/ExpandDims/Const")->op()); |
| Tensor expected_expand_dims_tensor(DT_INT32, TensorShape({})); |
| test::FillValues<int32>(&expected_expand_dims_tensor, {-1}); |
| test::ExpectTensorEqual<int32>( |
| expected_expand_dims_tensor, |
| GetNodeTensorAttr(*(node_lookup.at("gather/ExpandDims/Const")), |
| "value")); |
| |
| EXPECT_EQ(1, node_lookup.count("gather")); |
| EXPECT_EQ("ExpandDims", node_lookup.at("gather")->op()); |
| |
| EXPECT_EQ(1, node_lookup.count(shared_init_name)); |
| EXPECT_EQ("NoOp", node_lookup.at(shared_init_name)->op()); |
| |
| // Check connections |
| EXPECT_EQ("w/part_1/HashTable", |
| node_lookup.at("w/part_1/InitializeTable")->input(0)); |
| EXPECT_EQ("w/part_1/indices", |
| node_lookup.at("w/part_1/InitializeTable")->input(1)); |
| EXPECT_EQ("w/part_1/values", |
| node_lookup.at("w/part_1/InitializeTable")->input(2)); |
| |
| EXPECT_EQ("w/part_1/HashTable", |
| node_lookup.at("gather/LookupTableFind")->input(0)); |
| EXPECT_EQ("ids", node_lookup.at("gather/LookupTableFind")->input(1)); |
| EXPECT_EQ("gather/Const", |
| node_lookup.at("gather/LookupTableFind")->input(2)); |
| |
| EXPECT_EQ("gather/LookupTableFind", node_lookup.at("gather")->input(0)); |
| |
| // Check control dependency. |
| EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(), |
| node_lookup.at(shared_init_name)->input().end(), |
| "^w/part_1/InitializeTable"), |
| node_lookup.at(shared_init_name)->input().end()); |
| EXPECT_EQ(1, node_lookup.at(shared_init_name)->input().size()); |
| } |
| |
| void TestMultiPartition(bool gather_v2, bool include_shared_init, |
| bool test_variable, |
| const string& shared_init_name = "group_deps") { |
| // The 'ids' node is served input for two 'Gather's. |
| GraphDef graph_def; |
| |
| const auto checkpoint_path = |
| io::JoinPath(testing::TmpDir(), "checkpoint_multiple"); |
| // Build Graph: |
| // Shared input node |
| NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def); |
| |
| // Two partitions |
| NodeDef* w_node1; |
| NodeDef* w_node2; |
| NodeDef* zeros_const1; |
| NodeDef* zeros_shape1; |
| NodeDef* zeros_node1; |
| NodeDef* zeros_const2; |
| NodeDef* zeros_shape2; |
| NodeDef* zeros_node2; |
| NodeDef* assign_node1; |
| NodeDef* assign_node2; |
| |
| Tensor weights(DT_FLOAT, TensorShape({4, 1})); |
| test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001}); |
| if (!test_variable) { |
| w_node1 = CreateNode("w1/part_1", "Const", {}, &graph_def); |
| w_node2 = CreateNode("w2/part_1", "Const", {}, &graph_def); |
| SetNodeTensorAttr<float>("value", weights, w_node1); |
| SetNodeTensorAttr<float>("value", weights, w_node2); |
| } else { |
| NodeDef* save_const_node = |
| CreateNode("save/Const", "Const", {}, &graph_def); |
| |
| NodeDef* tensor_names_node = |
| CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def); |
| Tensor tensor_names_values(DT_STRING, TensorShape({2})); |
| test::FillValues<string>(&tensor_names_values, {"w1", "w2"}); |
| SetNodeTensorAttr<string>("value", tensor_names_values, |
| tensor_names_node); |
| |
| NodeDef* tensor_shapes_slices_node = CreateNode( |
| "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); |
| Tensor shapes_slices_val(DT_STRING, TensorShape({2})); |
| shapes_slices_val.flat<tstring>()(0) = "4 1 0,4:0,1"; |
| shapes_slices_val.flat<tstring>()(1) = "4 1 0,4:0,1"; |
| SetNodeTensorAttr<string>("value", shapes_slices_val, |
| tensor_shapes_slices_node); |
| |
| NodeDef* restore_node = CreateNode( |
| "save/RestoreV2", "RestoreV2", |
| {save_const_node, tensor_names_node, tensor_shapes_slices_node}, |
| &graph_def); |
| |
| w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def); |
| |
| zeros_shape1 = CreateNode("w1/part_1/Initializer/zeros/shape_as_tensor", |
| "Const", {}, &graph_def); |
| zeros_const1 = CreateNode("w1/part_1/Initializer/zeros/Const", "Const", |
| {}, &graph_def); |
| zeros_node1 = CreateNode("w1/part_1/Initializer/zeros", "Fill", |
| {zeros_shape1, zeros_const1}, &graph_def); |
| assign_node1 = CreateNode("w1/part_1/Assign", "Assign", |
| {w_node1, zeros_node1}, &graph_def); |
| |
| CreateNode("save/Assign", "Assign", {w_node1, restore_node}, &graph_def); |
| |
| w_node2 = CreateNode("w2/part_1", "VariableV2", {}, &graph_def); |
| zeros_shape2 = CreateNode("w2/part_1/Initializer/zeros/shape_as_tensor", |
| "Const", {}, &graph_def); |
| zeros_const2 = CreateNode("w2/part_1/Initializer/zeros/Const", "Const", |
| {}, &graph_def); |
| zeros_node2 = CreateNode("w2/part_1/Initializer/zeros", "Fill", |
| {zeros_shape2, zeros_const2}, &graph_def); |
| assign_node2 = CreateNode("w2/part_1/Assign", "Assign", |
| {w_node2, zeros_node2}, &graph_def); |
| |
| CreateNode("save/Assign_1", "Assign", {w_node2, restore_node}, |
| &graph_def); |
| |
| BundleWriter writer(Env::Default(), checkpoint_path); |
| TF_ASSERT_OK(writer.Add("w1", weights)); |
| TF_ASSERT_OK(writer.Add("w2", weights)); |
| TF_ASSERT_OK(writer.Finish()); |
| } |
| SetNodeAttr("dtype", DT_FLOAT, w_node1); |
| SetNodeAttr("dtype", DT_FLOAT, w_node2); |
| |
| NodeDef* identity_node1 = |
| CreateNode("w1/part_1/read", "Identity", {w_node1}, &graph_def); |
| NodeDef* identity_node2 = |
| CreateNode("w2/part_1/read", "Identity", {w_node2}, &graph_def); |
| MakeGather("gather1", gather_v2, identity_node1, input_node, &graph_def); |
| MakeGather("gather2", gather_v2, identity_node2, input_node, &graph_def); |
| |
| NodeDef* concat_axis_node = |
| CreateNode("linear/concat/axis", "Const", {}, &graph_def); |
| NodeDef* concat_node = CreateNode( |
| "concat/node", "ConcatV2", |
| {identity_node1, identity_node2, concat_axis_node}, &graph_def); |
| SetNodeAttr("N", 2, concat_node); |
| |
| // Shared init node |
| if (include_shared_init) { |
| if (!test_variable) { |
| CreateNode(shared_init_name, "NoOp", {}, &graph_def); |
| } else { |
| CreateNode(shared_init_name, "NoOp", {assign_node1, assign_node2}, |
| &graph_def, true); |
| } |
| } |
| |
| // Run the op. |
| GraphDef result; |
| TransformFuncContext context; |
| context.input_names = {"ids"}; |
| context.output_names = {"gather1", "gather2"}; |
| if (test_variable) { |
| context.params["input_checkpoint"] = {checkpoint_path}; |
| } |
| if (shared_init_name != "group_deps") { |
| context.params["group_init_node"] = {shared_init_name}; |
| } |
| TF_ASSERT_OK(SparsifyGather(graph_def, context, &result)); |
| |
| // Validation begins. |
| std::map<string, const NodeDef*> node_lookup; |
| MapNamesToNodes(result, &node_lookup); |
| |
| // Check nodes. |
| EXPECT_EQ(0, |
| node_lookup.count("w1/part_1/Initializer/zeros/shape_as_tensor")); |
| EXPECT_EQ(0, node_lookup.count("w1/part_1/Initializer/zeros/Const")); |
| EXPECT_EQ(0, node_lookup.count("w1/part_1/Initializer/zeros")); |
| EXPECT_EQ(0, node_lookup.count("w1/part_1/Assign")); |
| EXPECT_EQ(0, |
| node_lookup.count("w2/part_1/Initializer/zeros/shape_as_tensor")); |
| EXPECT_EQ(0, node_lookup.count("w2/part_1/Initializer/zeros/Const")); |
| EXPECT_EQ(0, node_lookup.count("w2/part_1/Initializer/zeros")); |
| EXPECT_EQ(0, node_lookup.count("w2/part_1/Assign")); |
| EXPECT_EQ(1, node_lookup.count("ids")); |
| EXPECT_EQ("Const", node_lookup.at("ids")->op()); |
| |
| EXPECT_EQ(1, node_lookup.count(shared_init_name)); |
| EXPECT_EQ("NoOp", node_lookup.at(shared_init_name)->op()); |
| |
| EXPECT_EQ(1, node_lookup.count("w1/part_1/indices")); |
| EXPECT_EQ("Const", node_lookup.at("w1/part_1/indices")->op()); |
| Tensor expected_indices_tensor1(DT_INT64, TensorShape({3})); |
| test::FillValues<int64>(&expected_indices_tensor1, {0, 2, 3}); |
| test::ExpectTensorEqual<int64>( |
| expected_indices_tensor1, |
| GetNodeTensorAttr(*(node_lookup.at("w1/part_1/indices")), "value")); |
| |
| EXPECT_EQ(1, node_lookup.count("w1/part_1/values")); |
| EXPECT_EQ("Const", node_lookup.at("w1/part_1/values")->op()); |
| Tensor expected_values_tensor1(DT_FLOAT, TensorShape({3})); |
| test::FillValues<float>(&expected_values_tensor1, {0.2, 1.2, 0.001}); |
| test::ExpectTensorNear<float>( |
| expected_values_tensor1, |
| GetNodeTensorAttr(*(node_lookup.at("w1/part_1/values")), "value"), |
| 1e-5); |
| |
| EXPECT_EQ(1, node_lookup.count("w1/part_1/HashTable")); |
| EXPECT_EQ("HashTable", node_lookup.at("w1/part_1/HashTable")->op()); |
| |
| EXPECT_EQ(1, node_lookup.count("w1/part_1/InitializeTable")); |
| EXPECT_EQ("InitializeTable", |
| node_lookup.at("w1/part_1/InitializeTable")->op()); |
| |
| // Nodes in "gather1" scope. |
| EXPECT_EQ(1, node_lookup.count("gather1/LookupTableFind")); |
| EXPECT_EQ("LookupTableFind", |
| node_lookup.at("gather1/LookupTableFind")->op()); |
| |
| EXPECT_EQ(1, node_lookup.count("gather1/Const")); |
| EXPECT_EQ("Const", node_lookup.at("gather1/Const")->op()); |
| Tensor expected_gather_default_tensor1(DT_FLOAT, TensorShape({})); |
| test::FillValues<float>(&expected_gather_default_tensor1, {0.0}); |
| test::ExpectTensorNear<float>( |
| expected_gather_default_tensor1, |
| GetNodeTensorAttr(*(node_lookup.at("gather1/Const")), "value"), 1e-5); |
| |
| EXPECT_EQ(1, node_lookup.count("gather1/ExpandDims/Const")); |
| EXPECT_EQ("Const", node_lookup.at("gather1/ExpandDims/Const")->op()); |
| Tensor expected_expand_dims_tensor1(DT_INT32, TensorShape({})); |
| test::FillValues<int32>(&expected_expand_dims_tensor1, {-1}); |
| test::ExpectTensorEqual<int32>( |
| expected_expand_dims_tensor1, |
| GetNodeTensorAttr(*(node_lookup.at("gather1/ExpandDims/Const")), |
| "value")); |
| |
| EXPECT_EQ(1, node_lookup.count("gather1")); |
| EXPECT_EQ("ExpandDims", node_lookup.at("gather1")->op()); |
| |
| EXPECT_EQ(1, node_lookup.count("w2/part_1/indices")); |
| EXPECT_EQ("Const", node_lookup.at("w2/part_1/indices")->op()); |
| Tensor expected_indices_tensor2(DT_INT64, TensorShape({3})); |
| test::FillValues<int64>(&expected_indices_tensor2, {0, 2, 3}); |
| test::ExpectTensorEqual<int64>( |
| expected_indices_tensor2, |
| GetNodeTensorAttr(*(node_lookup.at("w2/part_1/indices")), "value")); |
| |
| EXPECT_EQ(1, node_lookup.count("w2/part_1/values")); |
| EXPECT_EQ("Const", node_lookup.at("w2/part_1/values")->op()); |
| Tensor expected_values_tensor2(DT_FLOAT, TensorShape({3})); |
| test::FillValues<float>(&expected_values_tensor2, {0.2, 1.2, 0.001}); |
| test::ExpectTensorNear<float>( |
| expected_values_tensor2, |
| GetNodeTensorAttr(*(node_lookup.at("w2/part_1/values")), "value"), |
| 1e-5); |
| |
| EXPECT_EQ(1, node_lookup.count("w2/part_1/HashTable")); |
| EXPECT_EQ("HashTable", node_lookup.at("w2/part_1/HashTable")->op()); |
| |
| EXPECT_EQ(1, node_lookup.count("w2/part_1/InitializeTable")); |
| EXPECT_EQ("InitializeTable", |
| node_lookup.at("w2/part_1/InitializeTable")->op()); |
| |
| // Nodes in "gather2" scope. |
| EXPECT_EQ(1, node_lookup.count("gather2/LookupTableFind")); |
| EXPECT_EQ("LookupTableFind", |
| node_lookup.at("gather2/LookupTableFind")->op()); |
| |
| EXPECT_EQ(1, node_lookup.count("gather2/Const")); |
| EXPECT_EQ("Const", node_lookup.at("gather2/Const")->op()); |
| Tensor expected_gather_default_tensor2(DT_FLOAT, TensorShape({})); |
| test::FillValues<float>(&expected_gather_default_tensor2, {0.0}); |
| test::ExpectTensorNear<float>( |
| expected_gather_default_tensor2, |
| GetNodeTensorAttr(*(node_lookup.at("gather2/Const")), "value"), 1e-5); |
| |
| EXPECT_EQ(1, node_lookup.count("gather2/ExpandDims/Const")); |
| EXPECT_EQ("Const", node_lookup.at("gather2/ExpandDims/Const")->op()); |
| Tensor expected_expand_dims_tensor2(DT_INT32, TensorShape({})); |
| test::FillValues<int32>(&expected_expand_dims_tensor2, {-1}); |
| test::ExpectTensorEqual<int32>( |
| expected_expand_dims_tensor2, |
| GetNodeTensorAttr(*(node_lookup.at("gather2/ExpandDims/Const")), |
| "value")); |
| |
| EXPECT_EQ(1, node_lookup.count("gather2")); |
| EXPECT_EQ("ExpandDims", node_lookup.at("gather2")->op()); |
| |
| // Check connections |
| EXPECT_EQ("w1/part_1/HashTable", |
| node_lookup.at("w1/part_1/InitializeTable")->input(0)); |
| EXPECT_EQ("w1/part_1/indices", |
| node_lookup.at("w1/part_1/InitializeTable")->input(1)); |
| EXPECT_EQ("w1/part_1/values", |
| node_lookup.at("w1/part_1/InitializeTable")->input(2)); |
| |
| EXPECT_EQ("w2/part_1/HashTable", |
| node_lookup.at("w2/part_1/InitializeTable")->input(0)); |
| EXPECT_EQ("w2/part_1/indices", |
| node_lookup.at("w2/part_1/InitializeTable")->input(1)); |
| EXPECT_EQ("w2/part_1/values", |
| node_lookup.at("w2/part_1/InitializeTable")->input(2)); |
| |
| EXPECT_EQ("w1/part_1/HashTable", |
| node_lookup.at("gather1/LookupTableFind")->input(0)); |
| EXPECT_EQ("ids", node_lookup.at("gather1/LookupTableFind")->input(1)); |
| EXPECT_EQ("gather1/Const", |
| node_lookup.at("gather1/LookupTableFind")->input(2)); |
| EXPECT_EQ("gather1/LookupTableFind", node_lookup.at("gather1")->input(0)); |
| |
| EXPECT_EQ("w2/part_1/HashTable", |
| node_lookup.at("gather2/LookupTableFind")->input(0)); |
| EXPECT_EQ("ids", node_lookup.at("gather2/LookupTableFind")->input(1)); |
| EXPECT_EQ("gather2/Const", |
| node_lookup.at("gather2/LookupTableFind")->input(2)); |
| EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0)); |
| |
| EXPECT_EQ(0, node_lookup.count("linear/concat/axis")); |
| EXPECT_EQ(0, node_lookup.count("concat/node")); |
| |
| // Check control deps. |
| EXPECT_EQ(2, node_lookup.at(shared_init_name)->input_size()); |
| EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(), |
| node_lookup.at(shared_init_name)->input().end(), |
| "^w1/part_1/InitializeTable"), |
| node_lookup.at(shared_init_name)->input().end()); |
| |
| EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(), |
| node_lookup.at(shared_init_name)->input().end(), |
| "^w2/part_1/InitializeTable"), |
| node_lookup.at(shared_init_name)->input().end()); |
| } |
| void TestReadTensorSlice() { |
| const auto checkpoint_path = |
| io::JoinPath(testing::TmpDir(), "checkpoint_slice"); |
| |
| Tensor weights(DT_FLOAT, TensorShape({2, 1})); |
| test::FillValues<float>(&weights, {0.2, 0.000001}); |
| BundleWriter writer(Env::Default(), checkpoint_path); |
| TF_ASSERT_OK(writer.AddSlice("w", TensorShape({4, 1}), |
| TensorSlice::ParseOrDie("0,2:0,1"), weights)); |
| TF_ASSERT_OK(writer.Finish()); |
| |
| std::unique_ptr<BundleReader> reader( |
| new BundleReader(Env::Default(), checkpoint_path)); |
| |
| Tensor results; |
| TF_ASSERT_OK( |
| ReadTensorFromCheckpoint("w/part_0", reader, "4 1 0,2:0,1", &results)); |
| |
| test::ExpectTensorEqual<float>(weights, results); |
| } |
| }; |
| |
| TEST_F(SparsifyGatherTest, TestSinglePartition) { |
| TestSinglePartition(false, false, false, false); |
| TestSinglePartition(false, true, false, false); |
| TestSinglePartition(true, false, false, false); |
| TestSinglePartition(true, true, false, false); |
| TestSinglePartition(false, false, true, false); |
| TestSinglePartition(false, true, true, false); |
| TestSinglePartition(true, false, true, false); |
| TestSinglePartition(true, true, true, false); |
| TestSinglePartition(false, true, false, false, "shared_inits"); |
| TestSinglePartition(true, true, false, false, "shared_inits"); |
| TestSinglePartition(false, true, true, false, "shared_inits"); |
| TestSinglePartition(true, true, true, false, "shared_inits"); |
| |
| TestSinglePartition(false, false, false, true); |
| TestSinglePartition(false, true, false, true); |
| TestSinglePartition(true, false, false, true); |
| TestSinglePartition(true, true, false, true); |
| TestSinglePartition(false, false, true, true); |
| TestSinglePartition(false, true, true, true); |
| TestSinglePartition(true, false, true, true); |
| TestSinglePartition(true, true, true, true); |
| TestSinglePartition(false, true, false, true, "shared_inits"); |
| TestSinglePartition(true, true, false, true, "shared_inits"); |
| TestSinglePartition(false, true, true, true, "shared_inits"); |
| TestSinglePartition(true, true, true, true, "shared_inits"); |
| } |
| |
| TEST_F(SparsifyGatherTest, TestMultiPartition) { |
| TestMultiPartition(false, false, false); |
| TestMultiPartition(false, true, false); |
| TestMultiPartition(true, false, false); |
| TestMultiPartition(true, true, false); |
| TestMultiPartition(false, false, true); |
| TestMultiPartition(false, true, true); |
| TestMultiPartition(true, false, true); |
| TestMultiPartition(true, true, true); |
| TestMultiPartition(false, true, false, "shared_inits"); |
| TestMultiPartition(true, true, false, "shared_inits"); |
| TestMultiPartition(false, true, true, "shared_inits"); |
| TestMultiPartition(true, true, true, "shared_inits"); |
| } |
| |
| TEST_F(SparsifyGatherTest, TestTensorSlice) { TestReadTensorSlice(); } |
| |
| } // namespace graph_transforms |
| } // namespace tensorflow |