# blob: 1a925caab96b4b07eb6e05e525c34a9e646f4614 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the key functions in pruning library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training_util
class PruningHParamsTest(test.TestCase):
  """Checks that `pruning.Pruning` is built correctly from parsed hparams."""

  # Individual "key=value" settings; joined with commas into TEST_HPARAMS.
  PARAM_LIST = [
      "name=test",
      "threshold_decay=0.9",
      "pruning_frequency=10",
      "sparsity_function_end_step=100",
      "target_sparsity=0.9",
      "weight_sparsity_map=[conv1:0.8,conv2/kernel:0.8]",
      "block_dims_map=[dense1:4x4,dense2:1x4]",
  ]
  TEST_HPARAMS = ",".join(PARAM_LIST)

  def setUp(self):
    """Creates the global step, a sparsity variable and the parsed hparams."""
    super(PruningHParamsTest, self).setUp()
    # Global step variable for the graph used by the tests.
    self.global_step = training_util.get_or_create_global_step()
    # Sparsity variable that some tests hand to Pruning explicitly.
    self.sparsity = variables.VariableV1(0.5, name="sparsity")
    # Start from the library defaults and override them with TEST_HPARAMS.
    default_hparams = pruning.get_pruning_hparams()
    self.pruning_hparams = default_hparams.parse(self.TEST_HPARAMS)

  def testInit(self):
    """The spec held by a Pruning object reflects the parsed hparams."""
    pruner = pruning.Pruning(self.pruning_hparams)
    spec = pruner._spec
    self.assertEqual(spec.name, "test")
    self.assertAlmostEqual(spec.threshold_decay, 0.9)
    self.assertEqual(spec.pruning_frequency, 10)
    self.assertEqual(spec.sparsity_function_end_step, 100)
    self.assertAlmostEqual(spec.target_sparsity, 0.9)

  def testInitWithExternalSparsity(self):
    """A sparsity variable passed to the constructor is the one used."""
    with self.cached_session():
      pruner = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      variables.global_variables_initializer().run()
      observed_sparsity = pruner._sparsity.eval()
      self.assertAlmostEqual(observed_sparsity, 0.5)

  def testInitWithVariableReuse(self):
    """Two Pruning objects sharing one sparsity variable agree on its value."""
    with self.cached_session():
      first = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      second = pruning.Pruning(
          spec=self.pruning_hparams, sparsity=self.sparsity)
      variables.global_variables_initializer().run()
      observed_sparsity = first._sparsity.eval()
      self.assertAlmostEqual(observed_sparsity, 0.5)
      self.assertEqual(first._sparsity.eval(), second._sparsity.eval())
class PruningTest(test.TestCase):
  """Tests mask creation and mask updates performed by the pruning library."""

  def setUp(self):
    """Makes sure a global step exists before each test builds its graph."""
    super(PruningTest, self).setUp()
    self.global_step = training_util.get_or_create_global_step()

  def testCreateMask2D(self):
    """Applying a mask to fresh 2-D weights leaves their values unchanged."""
    width = 10
    height = 20
    with self.cached_session():
      weights = variables.VariableV1(
          random_ops.random_normal([width, height], stddev=1), name="weights")
      masked_weights = pruning.apply_mask(weights,
                                          variable_scope.get_variable_scope())
      variables.global_variables_initializer().run()
      weights_val = weights.eval()
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(weights_val, masked_weights_val)

  def testUpdateSingleMask(self):
    """One mask update at sparsity 0.95 keeps 5 of 100 weights non-zero."""
    with self.cached_session() as session:
      weights = variables.VariableV1(
          math_ops.linspace(1.0, 100.0, 100), name="weights")
      masked_weights = pruning.apply_mask(weights)
      sparsity = variables.VariableV1(0.95, name="sparsity")
      p = pruning.Pruning(sparsity=sparsity)
      p._spec.threshold_decay = 0.0
      mask_update_op = p.mask_update_op()
      variables.global_variables_initializer().run()
      # Before the update every weight is still non-zero.
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
      session.run(mask_update_op)
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked_weights_val), 5)

  def _blockMasking(self, hparams, weights, expected_mask):
    """Runs `_maybe_update_block_mask` and compares the resulting mask.

    Args:
      hparams: List of "key=value" strings; joined with commas and parsed into
        the pruning hparams.
      weights: Tensor handed to `Pruning._maybe_update_block_mask`.
      expected_mask: Nested list the evaluated new mask must equal.
    """
    threshold = variables.VariableV1(0.0, name="threshold")
    sparsity = variables.VariableV1(0.5, name="sparsity")
    test_spec = ",".join(hparams)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    # Set up pruning
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    with self.cached_session():
      variables.global_variables_initializer().run()
      _, new_mask = p._maybe_update_block_mask(weights, threshold)
      # Check if the mask is the same size as the weights
      self.assertAllEqual(new_mask.get_shape(), weights.get_shape())
      mask_val = new_mask.eval()
      self.assertAllEqual(mask_val, expected_mask)

  def testBlockMaskingWithNonnegativeBlockDimensions(self):
    """Block masking with 2x2 blocks for both MAX and AVG pooling."""
    param_list = ["block_height=2", "block_width=2", "threshold_decay=0"]
    weights_avg = constant_op.constant(
        [[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4],
         [0.3, 0.3, 0.4, 0.4]])
    weights_max = constant_op.constant(
        [[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0],
         [0.0, -0.3, 0.0, -0.4]])
    expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
                     [1., 1., 1., 1.], [1., 1., 1., 1.]]

    self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
                       expected_mask)
    self._blockMasking(param_list + ["block_pooling_function=AVG"], weights_avg,
                       expected_mask)

  def testBlockMaskingWithNegativeBlockDimensions(self):
    """Block masking with block_height=1 and block_width=-1."""
    # NOTE(review): the expected mask keeps or drops whole rows, so a width of
    # -1 presumably means "span the full dimension" -- confirm against the
    # pruning library.
    param_list = ["block_height=1", "block_width=-1", "threshold_decay=0"]
    weights_avg = constant_op.constant([[0.1, 0.1, 0.1, 0.1],
                                        [0.2, 0.2, 0.2, 0.2],
                                        [0.3, 0.3, 0.3, 0.3],
                                        [0.3, 0.3, 0.4, 0.4]])
    weights_max = constant_op.constant([[0.1, 0.0, 0.1, 0.0],
                                        [0.0, 0.1, 0.0, 0.2],
                                        [0.3, 0.0, 0.3, 0.0],
                                        [0.0, -0.3, 0.0, 0.4]])
    expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
                     [1., 1., 1., 1.], [1., 1., 1., 1.]]

    self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
                       expected_mask)
    self._blockMasking(param_list + ["block_pooling_function=AVG"], weights_avg,
                       expected_mask)

  def testBlockMaskingWithHigherDimensions(self):
    """Block masking with 2x2 blocks on weights that have a leading axis."""
    param_list = ["block_height=2", "block_width=2", "threshold_decay=0"]

    # Weights as in testBlockMaskingWithNonnegativeBlockDimensions, but with
    # one extra dimension.
    weights_avg = constant_op.constant(
        [[[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4],
          [0.3, 0.3, 0.4, 0.4]]])
    weights_max = constant_op.constant(
        [[[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0],
          [0.0, -0.3, 0.0, -0.4]]])
    expected_mask = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
                      [1., 1., 1., 1.], [1., 1., 1., 1.]]]

    self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
                       expected_mask)
    self._blockMasking(param_list + ["block_pooling_function=AVG"],
                       weights_avg, expected_mask)

  def testPartitionedVariableMasking(self):
    """A mask update at sparsity 0.5 halves a partitioned weight variable."""
    partitioner = partitioned_variables.variable_axis_size_partitioner(40)
    with self.cached_session() as session:
      with variable_scope.variable_scope("", partitioner=partitioner):
        sparsity = variables.VariableV1(0.5, name="Sparsity")
        weights = variable_scope.get_variable(
            "weights", initializer=math_ops.linspace(1.0, 100.0, 100))
        masked_weights = pruning.apply_mask(
            weights, scope=variable_scope.get_variable_scope())
      p = pruning.Pruning(sparsity=sparsity)
      p._spec.threshold_decay = 0.0
      mask_update_op = p.mask_update_op()
      variables.global_variables_initializer().run()
      # Nothing should be pruned before the first mask update; this is the
      # same pre-update check that testUpdateSingleMask performs.
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
      session.run(mask_update_op)
      masked_weights_val = masked_weights.eval()
      self.assertAllEqual(np.count_nonzero(masked_weights_val), 50)

  def testConditionalMaskUpdate(self):
    """The conditional update only changes the mask on selected iterations."""
    param_list = [
        "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6",
        "nbins=100"
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
    weights = variables.VariableV1(
        math_ops.linspace(1.0, 100.0, 100), name="weights")
    masked_weights = pruning.apply_mask(weights)
    sparsity = variables.VariableV1(0.00, name="sparsity")
    # Set up pruning
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    p._spec.threshold_decay = 0.0
    mask_update_op = p.conditional_mask_update_op()
    # Sparsity assigned on iteration i is sparsity_val[i]: 0.0, 0.1, ..., 0.9.
    sparsity_val = math_ops.linspace(0.0, 0.9, 10)
    increment_global_step = state_ops.assign_add(self.global_step, 1)
    non_zero_count = []
    with self.cached_session() as session:
      variables.global_variables_initializer().run()
      for i in range(10):
        session.run(state_ops.assign(sparsity, sparsity_val[i]))
        session.run(mask_update_op)
        session.run(increment_global_step)
        non_zero_count.append(np.count_nonzero(masked_weights.eval()))
    # The count is expected to change only on iterations 2, 4 and 6 (sparsity
    # 0.2, 0.4 and 0.6) and to stay at 40 for the remaining iterations.
    expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
    self.assertAllEqual(expected_non_zero_count, non_zero_count)

  def testWeightSpecificSparsity(self):
    """Entries of weight_sparsity_map set the sparsity reached per weight."""
    param_list = [
        "begin_pruning_step=1", "pruning_frequency=1", "end_pruning_step=100",
        "target_sparsity=0.5",
        "weight_sparsity_map=[layer1:0.6,layer2/weights:0.75,.*kernel:0.6]",
        "threshold_decay=0.0"
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    with variable_scope.variable_scope("layer1"):
      w1 = variables.VariableV1(
          math_ops.linspace(1.0, 100.0, 100), name="weights")
      _ = pruning.apply_mask(w1)
    with variable_scope.variable_scope("layer2"):
      w2 = variables.VariableV1(
          math_ops.linspace(1.0, 100.0, 100), name="weights")
      _ = pruning.apply_mask(w2)
    with variable_scope.variable_scope("layer3"):
      w3 = variables.VariableV1(
          math_ops.linspace(1.0, 100.0, 100), name="kernel")
      _ = pruning.apply_mask(w3)

    p = pruning.Pruning(pruning_hparams)
    mask_update_op = p.conditional_mask_update_op()
    increment_global_step = state_ops.assign_add(self.global_step, 1)

    with self.cached_session() as session:
      variables.global_variables_initializer().run()
      # Run past end_pruning_step (100) before reading the sparsities.
      for _ in range(110):
        session.run(mask_update_op)
        session.run(increment_global_step)

      self.assertAllClose(
          session.run(pruning.get_weight_sparsity()), [0.6, 0.75, 0.6])

  def testPerLayerBlockSparsity(self):
    """Entries of block_dims_map set the block shape used per weight."""
    param_list = [
        "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
        "block_pooling_function=AVG", "threshold_decay=0.0"
    ]

    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    with variable_scope.variable_scope("layer1"):
      w1 = constant_op.constant([[-0.1, 0.1], [-0.2, 0.2]], name="weights")
      pruning.apply_mask(w1)

    with variable_scope.variable_scope("layer2"):
      w2 = constant_op.constant([[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]],
                                name="weights")
      pruning.apply_mask(w2)

    sparsity = variables.VariableV1(0.5, name="sparsity")

    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    mask_update_op = p.mask_update_op()
    with self.cached_session() as session:
      variables.global_variables_initializer().run()
      session.run(mask_update_op)
      mask1_eval = session.run(pruning.get_masks()[0])
      mask2_eval = session.run(pruning.get_masks()[1])

      self.assertAllEqual(
          session.run(pruning.get_weight_sparsity()), [0.5, 0.5])

      self.assertAllEqual(mask1_eval, [[0.0, 0.0], [1., 1.]])
      self.assertAllEqual(mask2_eval, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
# Run the test cases in this module when it is executed as a script.
if __name__ == "__main__":
  test.main()