# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the key functions in the pruning library."""
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import numpy as np |
| |
| from tensorflow.contrib.model_pruning.python import pruning |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import partitioned_variables |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import state_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| from tensorflow.python.training import training_util |
| |
| |
class PruningHParamsTest(test.TestCase):
  """Tests constructing `pruning.Pruning` from a parsed hparams string."""

  # Individual "key=value" overrides. They are joined into the single
  # comma-separated spec string that setUp() hands to the hparams parser.
  PARAM_LIST = [
      "name=test",
      "threshold_decay=0.9",
      "pruning_frequency=10",
      "sparsity_function_end_step=100",
      "target_sparsity=0.9",
      "weight_sparsity_map=[conv1:0.8,conv2/kernel:0.8]",
      "block_dims_map=[dense1:4x4,dense2:1x4]",
  ]
  TEST_HPARAMS = ",".join(PARAM_LIST)

  def setUp(self):
    """Creates the global step, a sparsity variable and the parsed hparams."""
    super(PruningHParamsTest, self).setUp()
    # Make sure the graph contains a global step variable.
    self.global_step = training_util.get_or_create_global_step()
    # Sparsity variable that some tests pass to Pruning explicitly.
    self.sparsity = variables.VariableV1(0.5, name="sparsity")
    # Start from the library's default hparams and apply the overrides.
    default_hparams = pruning.get_pruning_hparams()
    self.pruning_hparams = default_hparams.parse(self.TEST_HPARAMS)

  def testInit(self):
    """The parsed overrides are readable through the Pruning object's spec."""
    spec = pruning.Pruning(self.pruning_hparams)._spec
    self.assertEqual(spec.name, "test")
    self.assertAlmostEqual(spec.threshold_decay, 0.9)
    self.assertEqual(spec.pruning_frequency, 10)
    self.assertEqual(spec.sparsity_function_end_step, 100)
    self.assertAlmostEqual(spec.target_sparsity, 0.9)

  def testInitWithExternalSparsity(self):
    """A sparsity variable passed to the constructor is the one Pruning uses."""
    with self.cached_session():
      pruner = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      variables.global_variables_initializer().run()
      self.assertAlmostEqual(pruner._sparsity.eval(), 0.5)

  def testInitWithVariableReuse(self):
    """Two Pruning objects given the same sparsity variable agree on it."""
    with self.cached_session():
      first = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      second = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      variables.global_variables_initializer().run()
      self.assertAlmostEqual(first._sparsity.eval(), 0.5)
      self.assertEqual(first._sparsity.eval(), second._sparsity.eval())
| |
| |
class PruningTest(test.TestCase):
  """Tests mask creation and mask updates performed by the pruning library."""

  def setUp(self):
    """Makes sure the graph has a global step before each test runs."""
    super(PruningTest, self).setUp()
    self.global_step = training_util.get_or_create_global_step()

  def testCreateMask2D(self):
    """Masked weights equal the raw weights right after initialization."""
    shape = [10, 20]  # [width, height]
    with self.cached_session():
      weights = variables.VariableV1(
          random_ops.random_normal(shape, stddev=1), name="weights")
      masked_weights = pruning.apply_mask(
          weights, variable_scope.get_variable_scope())
      variables.global_variables_initializer().run()
      raw_val = weights.eval()
      masked_val = masked_weights.eval()
      self.assertAllEqual(raw_val, masked_val)

  def testUpdateSingleMask(self):
    """One mask update at sparsity 0.95 leaves 5 of 100 weights non-zero."""
    with self.cached_session() as session:
      weights = variables.VariableV1(
          math_ops.linspace(1.0, 100.0, 100), name="weights")
      masked_weights = pruning.apply_mask(weights)
      sparsity = variables.VariableV1(0.95, name="sparsity")
      pruner = pruning.Pruning(sparsity=sparsity)
      pruner._spec.threshold_decay = 0.0
      update_op = pruner.mask_update_op()
      variables.global_variables_initializer().run()
      # Before the update, every weight passes through the mask.
      nonzero_before = np.count_nonzero(masked_weights.eval())
      self.assertAllEqual(nonzero_before, 100)
      session.run(update_op)
      nonzero_after = np.count_nonzero(masked_weights.eval())
      self.assertAllEqual(nonzero_after, 5)

  def _blockMasking(self, hparams, weights, expected_mask):
    """Builds a block mask for `weights` and compares it to `expected_mask`.

    Args:
      hparams: List of "key=value" strings; joined with commas and parsed into
        pruning hparams.
      weights: Tensor handed to `Pruning._maybe_update_block_mask`.
      expected_mask: Nested list the evaluated mask has to equal.
    """
    threshold = variables.VariableV1(0.0, name="threshold")
    sparsity = variables.VariableV1(0.5, name="sparsity")
    parsed_hparams = pruning.get_pruning_hparams().parse(",".join(hparams))

    pruner = pruning.Pruning(parsed_hparams, sparsity=sparsity)
    with self.cached_session():
      variables.global_variables_initializer().run()
      _, block_mask = pruner._maybe_update_block_mask(weights, threshold)
      # The mask has to have exactly the shape of the weights it covers.
      self.assertAllEqual(block_mask.get_shape(), weights.get_shape())
      self.assertAllEqual(block_mask.eval(), expected_mask)

  def testBlockMaskingWithNonnegativeBlockDimensions(self):
    """2x2 blocks with MAX and AVG pooling yield the same expected mask."""
    base_params = ["block_height=2", "block_width=2", "threshold_decay=0"]

    weights_avg = constant_op.constant([
        [0.1, 0.1, 0.2, 0.2],
        [0.1, 0.1, 0.2, 0.2],
        [0.3, 0.3, 0.4, 0.4],
        [0.3, 0.3, 0.4, 0.4],
    ])
    weights_max = constant_op.constant([
        [0.1, 0.0, 0.2, 0.0],
        [0.0, -0.1, 0.0, -0.2],
        [0.3, 0.0, 0.4, 0.0],
        [0.0, -0.3, 0.0, -0.4],
    ])
    expected_mask = [
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
    ]

    # MAX is exercised first, then AVG.
    cases = (
        ("block_pooling_function=MAX", weights_max),
        ("block_pooling_function=AVG", weights_avg),
    )
    for pooling_param, weights in cases:
      self._blockMasking(base_params + [pooling_param], weights, expected_mask)

  def testBlockMaskingWithNegativeBlockDimensions(self):
    """Block masking with block_height=1 and block_width=-1."""
    base_params = ["block_height=1", "block_width=-1", "threshold_decay=0"]

    weights_avg = constant_op.constant([
        [0.1, 0.1, 0.1, 0.1],
        [0.2, 0.2, 0.2, 0.2],
        [0.3, 0.3, 0.3, 0.3],
        [0.3, 0.3, 0.4, 0.4],
    ])
    weights_max = constant_op.constant([
        [0.1, 0.0, 0.1, 0.0],
        [0.0, 0.1, 0.0, 0.2],
        [0.3, 0.0, 0.3, 0.0],
        [0.0, -0.3, 0.0, 0.4],
    ])
    expected_mask = [
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
    ]

    # MAX is exercised first, then AVG.
    cases = (
        ("block_pooling_function=MAX", weights_max),
        ("block_pooling_function=AVG", weights_avg),
    )
    for pooling_param, weights in cases:
      self._blockMasking(base_params + [pooling_param], weights, expected_mask)

  def testBlockMaskingWithHigherDimensions(self):
    """2x2 block masking on weights that carry an extra leading dimension."""
    base_params = ["block_height=2", "block_width=2", "threshold_decay=0"]

    # Same values as in testBlockMaskingWithNonnegativeBlockDimensions, wrapped
    # in one additional outer dimension.
    weights_avg = constant_op.constant([[
        [0.1, 0.1, 0.2, 0.2],
        [0.1, 0.1, 0.2, 0.2],
        [0.3, 0.3, 0.4, 0.4],
        [0.3, 0.3, 0.4, 0.4],
    ]])
    weights_max = constant_op.constant([[
        [0.1, 0.0, 0.2, 0.0],
        [0.0, -0.1, 0.0, -0.2],
        [0.3, 0.0, 0.4, 0.0],
        [0.0, -0.3, 0.0, -0.4],
    ]])
    expected_mask = [[
        [0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
    ]]

    # MAX is exercised first, then AVG.
    cases = (
        ("block_pooling_function=MAX", weights_max),
        ("block_pooling_function=AVG", weights_avg),
    )
    for pooling_param, weights in cases:
      self._blockMasking(base_params + [pooling_param], weights, expected_mask)

  def testPartitionedVariableMasking(self):
    """A mask update at sparsity 0.5 on partitioned weights keeps 50 of 100."""
    partitioner = partitioned_variables.variable_axis_size_partitioner(40)
    with self.cached_session() as session:
      with variable_scope.variable_scope("", partitioner=partitioner):
        sparsity = variables.VariableV1(0.5, name="Sparsity")
        weights = variable_scope.get_variable(
            "weights", initializer=math_ops.linspace(1.0, 100.0, 100))
        masked_weights = pruning.apply_mask(
            weights, scope=variable_scope.get_variable_scope())
      pruner = pruning.Pruning(sparsity=sparsity)
      pruner._spec.threshold_decay = 0.0
      update_op = pruner.mask_update_op()
      variables.global_variables_initializer().run()
      # Evaluate once before the update; this value is not checked.
      masked_weights.eval()
      session.run(update_op)
      nonzero_after = np.count_nonzero(masked_weights.eval())
      self.assertAllEqual(nonzero_after, 50)

  def testConditionalMaskUpdate(self):
    """Checks the non-zero weight count after each of ten conditional updates.

    The sparsity variable is set to 0.0, 0.1, ..., 0.9 in turn. Each iteration
    runs the conditional mask update, then increments the global step, then
    records the number of non-zero masked weights.
    """
    spec_string = ",".join([
        "pruning_frequency=2",
        "begin_pruning_step=1",
        "end_pruning_step=6",
        "nbins=100",
    ])
    pruning_hparams = pruning.get_pruning_hparams().parse(spec_string)
    weights = variables.VariableV1(
        math_ops.linspace(1.0, 100.0, 100), name="weights")
    masked_weights = pruning.apply_mask(weights)
    sparsity = variables.VariableV1(0.00, name="sparsity")

    pruner = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    pruner._spec.threshold_decay = 0.0
    update_op = pruner.conditional_mask_update_op()
    sparsity_schedule = math_ops.linspace(0.0, 0.9, 10)
    bump_global_step = state_ops.assign_add(self.global_step, 1)

    observed_counts = []
    with self.cached_session() as session:
      variables.global_variables_initializer().run()
      for step in range(10):
        set_sparsity = state_ops.assign(sparsity, sparsity_schedule[step])
        session.run(set_sparsity)
        session.run(update_op)
        session.run(bump_global_step)
        observed_counts.append(np.count_nonzero(masked_weights.eval()))

    # The count drops at loop indices 2, 4 and 6 and stays fixed afterwards.
    expected_counts = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
    self.assertAllEqual(expected_counts, observed_counts)

  def testWeightSpecificSparsity(self):
    """Reported sparsities are [0.6, 0.75, 0.6] for the three masked layers."""
    spec_string = ",".join([
        "begin_pruning_step=1",
        "pruning_frequency=1",
        "end_pruning_step=100",
        "target_sparsity=0.5",
        "weight_sparsity_map=[layer1:0.6,layer2/weights:0.75,.*kernel:0.6]",
        "threshold_decay=0.0",
    ])
    pruning_hparams = pruning.get_pruning_hparams().parse(spec_string)

    # (variable scope, variable name) for each masked weight, in the order the
    # sparsities are compared at the end of the test.
    layers = (
        ("layer1", "weights"),
        ("layer2", "weights"),
        ("layer3", "kernel"),
    )
    for scope_name, variable_name in layers:
      with variable_scope.variable_scope(scope_name):
        layer_weights = variables.VariableV1(
            math_ops.linspace(1.0, 100.0, 100), name=variable_name)
        pruning.apply_mask(layer_weights)

    pruner = pruning.Pruning(pruning_hparams)
    update_op = pruner.conditional_mask_update_op()
    bump_global_step = state_ops.assign_add(self.global_step, 1)

    with self.cached_session() as session:
      variables.global_variables_initializer().run()
      for _ in range(110):
        session.run(update_op)
        session.run(bump_global_step)

      achieved = session.run(pruning.get_weight_sparsity())
      self.assertAllClose(achieved, [0.6, 0.75, 0.6])

  def testPerLayerBlockSparsity(self):
    """Per-layer block dims (1x1 and 1x2) with AVG pooling at sparsity 0.5."""
    spec_string = ",".join([
        "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
        "block_pooling_function=AVG",
        "threshold_decay=0.0",
    ])
    pruning_hparams = pruning.get_pruning_hparams().parse(spec_string)

    with variable_scope.variable_scope("layer1"):
      layer1_weights = constant_op.constant(
          [[-0.1, 0.1], [-0.2, 0.2]], name="weights")
      pruning.apply_mask(layer1_weights)

    with variable_scope.variable_scope("layer2"):
      layer2_weights = constant_op.constant(
          [[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]], name="weights")
      pruning.apply_mask(layer2_weights)

    sparsity = variables.VariableV1(0.5, name="sparsity")

    pruner = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    update_op = pruner.mask_update_op()
    with self.cached_session() as session:
      variables.global_variables_initializer().run()
      session.run(update_op)
      layer1_mask = session.run(pruning.get_masks()[0])
      layer2_mask = session.run(pruning.get_masks()[1])

      achieved = session.run(pruning.get_weight_sparsity())
      self.assertAllEqual(achieved, [0.5, 0.5])

      self.assertAllEqual(layer1_mask, [[0.0, 0.0], [1., 1.]])
      self.assertAllEqual(layer2_mask, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
| |
| |
# Run the test cases in this module through the TensorFlow test runner when
# the file is executed directly as a script.
if __name__ == "__main__":
  test.main()