| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for tensorflow.python.framework.graph_util.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import numpy as np |
| |
| from tensorflow.core.framework import attr_value_pb2 |
| from tensorflow.core.framework import graph_pb2 |
| from tensorflow.core.framework import node_def_pb2 |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.core.protobuf import meta_graph_pb2 |
| from tensorflow.python import keras |
| from tensorflow.python.client import session |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import function |
| from tensorflow.python.framework import graph_util |
| from tensorflow.python.framework import importer |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.grappler import tf_optimizer |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import gen_math_ops |
| from tensorflow.python.ops import gen_state_ops |
| from tensorflow.python.ops import math_ops # pylint: disable=unused-import |
| from tensorflow.python.ops import math_ops as math_ops_lib |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| from tensorflow.python.training.saver import export_meta_graph |
| |
| |
def test_device_func_pin_variable_to_cpu(op):
  """Device function used by the tests in this file to pin variables to CPU.

  Args:
    op: The operation being placed; only `op.device` and `op.node_def.op`
      are read.

  Returns:
    `op.device` if the op already has a device assigned. Otherwise "/cpu:0"
    for ops of type "Variable" or "VariableV2", and the unchanged (falsy)
    `op.device` for every other op type.
  """
  if op.device:
    # Respect an explicitly requested device.
    return op.device
  return "/cpu:0" if op.node_def.op in ("Variable", "VariableV2") else op.device


# This helper is not a test even though its name starts with "test_". Mark it
# so pytest does not collect it (collection would fail with
# "fixture 'op' not found").
test_device_func_pin_variable_to_cpu.__test__ = False
| |
| |
class DeviceFunctionsTest(test.TestCase):
  """Tests device placement scopes and the GraphDef helpers in graph_util."""

  def testTwoDeviceFunctions(self):
    """Checks device-function scopes that are entered, exited and nested."""

    def make_variable(name):
      # Creates a float32 variable op of shape [1]; its device is decided by
      # whichever device scopes are active where this is called.
      return gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name=name,
          container="",
          shared_name="")

    with ops.Graph().as_default() as g:
      var_0 = make_variable("var_0")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_1 = make_variable("var_1")
      # Outside of any device scope again.
      var_2 = make_variable("var_2")
      var_3 = make_variable("var_3")
      with g.device(test_device_func_pin_variable_to_cpu):
        var_4 = make_variable("var_4")
        # An explicit device string nested inside the device function.
        with g.device("/device:GPU:0"):
          var_5 = make_variable("var_5")
        # Back under the device function only.
        var_6 = make_variable("var_6")

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, None)
    self.assertDeviceEqual(var_3.device, None)
    self.assertDeviceEqual(var_4.device, "/device:CPU:0")
    self.assertDeviceEqual(var_5.device, "/device:GPU:0")
    self.assertDeviceEqual(var_6.device, "/device:CPU:0")

  @test_util.run_v1_only("b/120545219")
  def testNestedDeviceFunctions(self):
    with ops.Graph().as_default():
      var_0 = variables.VariableV1(0)
      with ops.device(test_device_func_pin_variable_to_cpu):
        var_1 = variables.VariableV1(1)
        with ops.device(lambda op: "/device:GPU:0"):
          var_2 = variables.VariableV1(2)
        with ops.device("/device:GPU:0"):  # Implicit merging device function.
          var_3 = variables.VariableV1(3)

    self.assertDeviceEqual(var_0.device, None)
    self.assertDeviceEqual(var_1.device, "/device:CPU:0")
    self.assertDeviceEqual(var_2.device, "/device:GPU:0")
    self.assertDeviceEqual(var_3.device, "/device:GPU:0")

  def testExplicitDevice(self):
    with ops.Graph().as_default() as g:
      const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/job:ps"):
        const_5 = constant_op.constant(5.0)

    self.assertDeviceEqual(const_0.device, None)
    self.assertDeviceEqual(const_1.device, "/device:GPU:0")
    self.assertDeviceEqual(const_2.device, "/device:GPU:1")
    self.assertDeviceEqual(const_3.device, "/device:CPU:0")
    self.assertDeviceEqual(const_4.device, "/device:CPU:1")
    self.assertDeviceEqual(const_5.device, "/job:ps")

  def testDefaultDevice(self):
    with ops.Graph().as_default() as g, g.device(
        test_device_func_pin_variable_to_cpu):
      with g.device("/job:ps"):
        const_0 = constant_op.constant(5.0)
      with g.device("/device:GPU:0"):
        const_1 = constant_op.constant(5.0)
      with g.device("/device:GPU:1"):
        const_2 = constant_op.constant(5.0)
      with g.device("/device:CPU:0"):
        const_3 = constant_op.constant(5.0)
      with g.device("/device:CPU:1"):
        const_4 = constant_op.constant(5.0)
      with g.device("/replica:0"):
        const_5 = constant_op.constant(5.0)

    self.assertDeviceEqual(const_0.device, "/job:ps")
    self.assertDeviceEqual(const_1.device, "/device:GPU:0")
    self.assertDeviceEqual(const_2.device, "/device:GPU:1")
    self.assertDeviceEqual(const_3.device, "/device:CPU:0")
    self.assertDeviceEqual(const_4.device, "/device:CPU:1")
    self.assertDeviceEqual(const_5.device, "/replica:0")

  def testExtractSubGraph(self):
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    n1.input.extend(["n5"])
    n2 = graph_def.node.add()
    n2.name = "n2"
    # Take the first output of the n1 node as the input.
    n2.input.extend(["n1:0"])
    n3 = graph_def.node.add()
    n3.name = "n3"
    # Add a control input (which isn't really needed by the kernel, but
    # rather to enforce execution order between nodes).
    n3.input.extend(["^n2"])
    # n4 is not reachable from n3, so it must be pruned from the sub-graph.
    n4 = graph_def.node.add()
    n4.name = "n4"

    # It is fine to have loops in the graph as well.
    n5 = graph_def.node.add()
    n5.name = "n5"
    n5.input.extend(["n1"])

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    # Checks both the preserved node order and that nothing else (n4) is kept.
    self.assertEqual(["n1", "n2", "n3", "n5"],
                     [node.name for node in sub_graph.node])

  def testExtractSubGraphWithInvalidDestNodes(self):
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
    with self.assertRaisesRegex(TypeError, "must be a list"):
      graph_util.extract_sub_graph(graph_def, "n1")

  def create_node_def(self, op, name, inputs):
    """Returns a NodeDef with the given op type, name and input list."""
    new_node = node_def_pb2.NodeDef()
    new_node.op = op
    new_node.name = name
    new_node.input.extend(inputs)
    return new_node

  def create_constant_node_def(self,
                               name,
                               value,
                               dtype,
                               shape=None,
                               inputs=None):
    """Returns a "Const" NodeDef holding `value` with the given dtype/shape."""
    node = self.create_node_def("Const", name, inputs or [])
    self.set_attr_dtype(node, "dtype", dtype)
    self.set_attr_tensor(node, "value", value, dtype, shape)
    return node

  def set_attr_dtype(self, node, key, value):
    """Sets attribute `key` of `node` to the DType `value`."""
    node.attr[key].CopyFrom(
        attr_value_pb2.AttrValue(type=value.as_datatype_enum))

  def set_attr_tensor(self, node, key, value, dtype, shape=None):
    """Sets attribute `key` of `node` to a TensorProto built from `value`."""
    node.attr[key].CopyFrom(
        attr_value_pb2.AttrValue(
            tensor=tensor_util.make_tensor_proto(
                value, dtype=dtype, shape=shape)))

  def testRemoveTrainingNodes(self):
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"
    graph_def = graph_pb2.GraphDef()
    a_constant = self.create_constant_node_def(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = self.create_node_def("CheckNumerics", a_check_name,
                                        [a_constant_name])
    graph_def.node.extend([a_check_node])
    a_identity_node = self.create_node_def(
        "Identity", a_identity_name, [a_constant_name, "^" + a_check_name])
    graph_def.node.extend([a_identity_node])
    b_constant = self.create_constant_node_def(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = self.create_node_def("CheckNumerics", b_check_name,
                                        [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = self.create_node_def(
        "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    add_node = self.create_node_def("Add", add_name,
                                    [a_identity_name, b_identity_name])
    self.set_attr_dtype(add_node, "T", dtypes.float32)
    graph_def.node.extend([add_node])

    # Expected: CheckNumerics and Identity nodes are gone and "add" reads the
    # constants directly.
    expected_output = graph_pb2.GraphDef()
    a_constant = self.create_constant_node_def(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant])
    b_constant = self.create_constant_node_def(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant])
    add_node = self.create_node_def("Add", add_name,
                                    [a_constant_name, b_constant_name])
    self.set_attr_dtype(add_node, "T", dtypes.float32)
    expected_output.node.extend([add_node])

    output = graph_util.remove_training_nodes(graph_def)
    self.assertProtoEquals(expected_output, output)

  def testRemoveIdentityChains(self):
    """Check that chains of Identity nodes are correctly pruned.

    Create a chain of four nodes, A, B, C, and D where A inputs B, B inputs C,
    and C inputs D. Nodes B and C are "Identity" and should be pruned, resulting
    in the nodes A and D, where A inputs D.
    """
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_node_def("Aop", "A", ["B"]),
        self.create_node_def("Identity", "B", ["C"]),
        self.create_node_def("Identity", "C", ["D"]),
        self.create_node_def("Dop", "D", [])
    ])

    expected_graph_def = graph_pb2.GraphDef()
    expected_graph_def.node.extend([
        self.create_node_def("Aop", "A", ["D"]),
        self.create_node_def("Dop", "D", [])
    ])

    self.assertProtoEquals(expected_graph_def,
                           graph_util.remove_training_nodes(graph_def))

  def testRemoveIdentityUsedAsControlInputInConst(self):
    """Check that Identity nodes used as control inputs are not removed."""
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend([
        self.create_constant_node_def("C", 1, dtypes.float32, inputs=["^I"]),
        self.create_node_def("Identity", "I", ["Base"]),
        self.create_node_def("BaseOp", "Base", [])
    ])

    self.assertProtoEquals(graph_def,
                           graph_util.remove_training_nodes(graph_def))
| |
| |
class ConvertVariablesToConstantsTest(test.TestCase):
  """Tests for `graph_util.convert_variables_to_constants` (graph freezing)."""

  def _get_tensors(self, sess, tensor_list):
    """Returns a list of Tensor objects from the Session."""
    return [
        sess.graph.get_tensor_by_name(tensor.name) for tensor in tensor_list
    ]

  def _get_tensor_names(self, tensors):
    """Returns a list of string names for the tensors specified."""
    return [tensor.name.split(":")[0] for tensor in tensors]

  def _evaluate_graph_def(self, graph_def, inputs, outputs, input_data):
    """Evaluates the GraphDef using Sessions.

    Args:
      graph_def: GraphDef to import into a fresh Graph.
      inputs: Tensors whose names identify the feeds in the imported graph.
      outputs: Tensors whose names identify the fetches in the imported graph.
      input_data: Values fed to `inputs`, in the same order.

    Returns:
      The list of fetched output values.
    """
    with ops.Graph().as_default() as graph:
      importer.import_graph_def(graph_def, name="")

    # Use the session as a context manager so it is closed once the outputs
    # have been fetched (previously the session was never closed).
    with session.Session(graph=graph) as sess:
      input_tensors = self._get_tensors(sess, inputs)
      output_tensors = self._get_tensors(sess, outputs)
      return sess.run(
          output_tensors, feed_dict=dict(zip(input_tensors, input_data)))

  def _ensure_no_variables_in_graph(self, graph_def):
    """Ensures there are no variables in the graph."""
    for node in graph_def.node:
      self.assertNotIn(
          node.op, ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])

  def _test_converted_keras_model(self, model, constant_graph_def, input_data):
    """Compares the converted Keras model."""
    expected_value = model.predict(input_data)
    actual_value = self._evaluate_graph_def(constant_graph_def, model.inputs,
                                            model.outputs, [input_data])
    np.testing.assert_almost_equal(np.array([expected_value]), actual_value, 5)

  def _test_variable_to_const_conversion(self, use_resource):
    """Freezes a graph with one used and one unused variable.

    Args:
      use_resource: Whether to create resource variables instead of
        reference variables.
    """
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=use_resource):
        variable_node = variable_scope.get_variable(
            "variable_node", initializer=1.0)
        another_variable = variable_scope.get_variable(
            "unused_variable_node", initializer=1.0)
        output_node = math_ops_lib.multiply(
            variable_node, 2.0, name="output_node")
        with session.Session() as sess:
          self.evaluate(variable_node.initializer)
          output = self.evaluate(output_node)
          self.assertNear(2.0, output, 0.00001)
          variable_graph_def = sess.graph.as_graph_def()
          # First get the constant_graph_def when variable_names_whitelist is
          # set, note that if variable_names_whitelist is not set an error will
          # be thrown because unused_variable_node is not initialized.
          constant_graph_def = graph_util.convert_variables_to_constants(
              sess,
              variable_graph_def, ["output_node"],
              variable_names_whitelist={"variable_node"})

          # Then initialize the unused variable, and get another
          # constant_graph_def when variable_names_whitelist is not set.
          self.evaluate(another_variable.initializer)
          constant_graph_def_without_variable_whitelist = (
              graph_util.convert_variables_to_constants(
                  sess, variable_graph_def, ["output_node"]))

          # The unused variable should be cleared so the two graphs should be
          # equivalent.
          self.assertEqual(
              str(constant_graph_def),
              str(constant_graph_def_without_variable_whitelist))

          # Test variable name black list. This should result in the variable
          # not being a const.
          constant_graph_def_with_blacklist = (
              graph_util.convert_variables_to_constants(
                  sess,
                  variable_graph_def, ["output_node"],
                  variable_names_blacklist={"variable_node"}))
          # Use a distinct name so the NodeDef does not shadow the variable.
          blacklisted_node = None
          for node in constant_graph_def_with_blacklist.node:
            if node.name == "variable_node":
              blacklisted_node = node
          self.assertIsNotNone(blacklisted_node)
          if use_resource:
            self.assertEqual(blacklisted_node.op, "VarHandleOp")
          else:
            self.assertEqual(blacklisted_node.op, "VariableV2")

    # Now we make sure the variable is now a constant, and that the graph still
    # produces the expected result.
    with ops.Graph().as_default():
      _ = importer.import_graph_def(constant_graph_def, name="")
      self.assertEqual(4, len(constant_graph_def.node))
      self._ensure_no_variables_in_graph(constant_graph_def)
      with session.Session() as sess:
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = self.evaluate(output_node)
        self.assertNear(2.0, output, 0.00001)

  def _test_convert_variables_with_functions(self, inline_functions):
    """Freezes a graph with functions.

    Args:
      inline_functions: If True, run Grappler's function optimizer on the
        graph before freezing it.
    """

    @function.Defun(dtypes.float32)
    def plus_one(x):
      return x + 1.0

    with ops.Graph().as_default():
      variable_node = variables.Variable(1.0, name="variable_node")
      _ = variables.Variable(1.0, name="unused_variable_node")
      defun_node = plus_one(variable_node)
      _ = math_ops_lib.multiply(defun_node, 2.0, name="output_node")

      with session.Session() as sess:
        self.evaluate(variables.variables_initializer([variable_node]))
        variable_graph_def = sess.graph.as_graph_def()

        if inline_functions:
          # Run Grappler to create the VarHandleOp --> Placeholder -->
          # ResourceVariable pattern.
          meta_graph = export_meta_graph(graph_def=variable_graph_def)
          fetch_collection = meta_graph_pb2.CollectionDef()
          fetch_collection.node_list.value.extend(
              ["variable_node", "output_node"])
          meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

          # Initialize RewriterConfig with everything disabled except function
          # inlining.
          config = config_pb2.ConfigProto()
          rewrite_options = config.graph_options.rewrite_options
          rewrite_options.optimizers.append("function")
          variable_graph_def = tf_optimizer.OptimizeGraph(config, meta_graph)

        constant_graph_def = graph_util.convert_variables_to_constants(
            sess, variable_graph_def, ["output_node"])

    self._ensure_no_variables_in_graph(constant_graph_def)

  def testReferenceVariables(self):
    """Freezes a graph with reference variables."""
    self._test_variable_to_const_conversion(use_resource=False)

  def testResourceVariables(self):
    """Freezes a graph with resource variables."""
    self._test_variable_to_const_conversion(use_resource=True)

  def testWithFunctions(self):
    """Freezes a graph with functions."""
    self._test_convert_variables_with_functions(inline_functions=False)

  def testWithInlinedFunctions(self):
    """Freezes a graph with functions that have been inlined using Grappler."""
    self._test_convert_variables_with_functions(inline_functions=True)

  @test_util.run_v1_only("Incompatible with TF 2.0")
  def testWithEmbeddings(self):
    """Freezes a graph with embeddings."""
    state_input = keras.layers.Input(
        shape=(1,), name="state_input", dtype="int32")
    output = keras.layers.Embedding(
        output_dim=16, input_dim=100, input_length=1, name="state")(
            state_input)
    model = keras.models.Model(inputs=[state_input], outputs=[output])
    model.compile(
        loss={"state": "sparse_categorical_crossentropy"}, optimizer="adam")

    # Freeze the graph.
    sess = keras.backend.get_session()
    variable_graph_def = sess.graph_def
    output_tensor = self._get_tensor_names(model.outputs)
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, variable_graph_def, output_tensor)

    # Validate converted graph. Draw indices from [0, input_dim); casting
    # random_sample() output to int32 would truncate every index to 0.
    input_data = np.random.randint(0, 100, size=(1, 1), dtype=np.int32)
    self._ensure_no_variables_in_graph(constant_graph_def)
    self._test_converted_keras_model(model, constant_graph_def, input_data)

  def testGraphWithSwitch(self):
    """Freezes a graph which contains a Switch with type RESOURCE_DT."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("", use_resource=True):
        x = variable_scope.get_variable("var_x", initializer=1.0)
        y = variable_scope.get_variable("var_y", initializer=2.0)
        f1 = lambda: variable_scope.get_variable("var_f1", initializer=17.0)
        f2 = lambda: variable_scope.get_variable("var_f2", initializer=23.0)
        cond_node = control_flow_ops.case([(gen_math_ops.less(x, y), f1)],
                                          default=f2)
        _ = math_ops_lib.multiply(cond_node, 2.0, name="output_node")

        with session.Session() as sess:
          sess.run(variables.global_variables_initializer())
          variable_graph_def = sess.graph.as_graph_def()

          constant_graph_def = graph_util.convert_variables_to_constants(
              sess, variable_graph_def, ["output_node"])

    self._ensure_no_variables_in_graph(constant_graph_def)

  @test_util.run_v1_only("Incompatible with TF 2.0")
  def testLSTM(self):
    """Freezes a Keras LSTM."""
    model = keras.models.Sequential(
        [keras.layers.LSTM(units=10, input_shape=(10, 10))])

    # Freeze the model.
    sess = keras.backend.get_session()
    variable_graph_def = sess.graph_def
    output_tensor = self._get_tensor_names(model.outputs)
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, variable_graph_def, output_tensor)

    # Validate converted graph. Keep the random samples real-valued; casting
    # random_sample() output to int32 would truncate every element to 0.
    input_data = np.array(
        np.random.random_sample([10, 10, 10]), dtype=np.float32)
    self._ensure_no_variables_in_graph(constant_graph_def)
    self._test_converted_keras_model(model, constant_graph_def, input_data)
| |
| |
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  test.main()