# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions to add support for magnitude-based model pruning.
# Adds variables and ops to the graph to enable
# elementwise masking of weights
apply_mask(weights)
# Returns a list containing the sparsity of each of the weight tensors
get_weight_sparsity()
# Returns a list of all the masked weight tensors
get_masked_weights()
# Returns a list of all the mask tensorflow variables
get_masks()
# Returns a list of all the thresholds
get_thresholds()
# Returns a list of all the weight tensors that have been masked
get_weights()
The Pruning class uses a tf.HParams object to set up the
parameters for model pruning. Here's a typical usage:
# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
# Create a pruning object using the pruning_hparams
p = pruning.Pruning(pruning_hparams)
# Add mask update ops to the graph
mask_update_op = p.conditional_mask_update_op()
# Add the summaries
p.add_pruning_summaries()
# Run the op
session.run(mask_update_op)
# A Pruning object also accepts an externally defined sparsity variable:
sparsity = tf.Variable(0.5, name="ConstantSparsity")
p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.contrib.model_pruning.python import pruning_utils
from tensorflow.contrib.model_pruning.python.layers import core_layers as core
from tensorflow.contrib.training.python.training import hparam
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
_MASK_COLLECTION = core.MASK_COLLECTION
_THRESHOLD_COLLECTION = core.THRESHOLD_COLLECTION
_MASKED_WEIGHT_COLLECTION = core.MASKED_WEIGHT_COLLECTION
_WEIGHT_COLLECTION = core.WEIGHT_COLLECTION
_MASKED_WEIGHT_NAME = core.MASKED_WEIGHT_NAME
def apply_mask(x, scope=''):
"""Apply mask to a given weight tensor.
Args:
x: Input weight tensor
scope: The current variable scope. Defaults to "".
Returns:
Tensor representing masked_weights
"""
mask = pruning_utils.weight_mask_variable(x, scope)
threshold = pruning_utils.weight_threshold_variable(x, scope)
  # Add masked_weights in the weights name scope so as to make it easier
  # for the quantization library to add quant ops.
masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)
  # Make sure the mask for a given variable is not added multiple times to the
  # collection. This is particularly important when applying a mask to an
  # RNN's weight variables.
if mask not in ops.get_collection_ref(_MASK_COLLECTION):
ops.add_to_collection(_THRESHOLD_COLLECTION, threshold)
ops.add_to_collection(_MASK_COLLECTION, mask)
ops.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked_weights)
ops.add_to_collection(_WEIGHT_COLLECTION, x)
return masked_weights
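# A minimal usage sketch for apply_mask (illustrative; the layer below is
# hypothetical and not part of this module):
#
#   import tensorflow as tf
#
#   inputs = tf.placeholder(tf.float32, [None, 784])
#   weights = tf.get_variable('weights', [784, 10])
#   bias = tf.get_variable('bias', [10], initializer=tf.zeros_initializer())
#   # The layer must consume the masked tensor, not the raw variable, so
#   # that pruned connections are actually zeroed in the forward pass.
#   logits = tf.matmul(inputs, apply_mask(weights)) + bias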
def get_masked_weights():
return ops.get_collection(_MASKED_WEIGHT_COLLECTION)
def get_masks():
return ops.get_collection(_MASK_COLLECTION)
def get_thresholds():
return ops.get_collection(_THRESHOLD_COLLECTION)
def get_weights():
return ops.get_collection(_WEIGHT_COLLECTION)
def get_weight_sparsity():
"""Get sparsity of the weights.
Args:
None
Returns:
A list containing the sparsity of each of the weight tensors
"""
masks = get_masks()
return [nn_impl.zero_fraction(mask) for mask in masks]
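# For example (illustrative), after running a mask update op one can check
# the realized per-layer sparsity with:
#
#   session.run(get_weight_sparsity())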
def get_pruning_hparams():
"""Get a tf.HParams object with the default values for the hyperparameters.
name: string
name of the pruning specification. Used for adding summaries and ops under
a common tensorflow name_scope
begin_pruning_step: integer
the global step at which to begin pruning
end_pruning_step: integer
      the global step at which to terminate pruning. Defaults to -1, implying
      that pruning continues until training stops
    weight_sparsity_map: list of strings
      comma separated list of {weight_variable_name:target sparsity} or
      {regex:target sparsity} pairs.
      For layers/weights not in this list, sparsity as specified by the
      target_sparsity hyperparameter is used.
      E.g. [conv1:0.9,conv2/kernel:0.8]
    block_dims_map: list of strings
      comma separated list of {weight variable name:block_height x block_width}
      or {regex:block_height x block_width} pairs. For layers/weights not in
      this list, the block dims specified by the block_height and block_width
      hyperparameters are used. E.g. [dense1:4x4,dense2:1x16,dense3:1x1]
threshold_decay: float
the decay factor to use for exponential decay of the thresholds
pruning_frequency: integer
How often should the masks be updated? (in # of global_steps)
nbins: integer
number of bins to use for histogram computation
block_height: integer
number of rows in a block (defaults to 1)
block_width: integer
number of cols in a block (defaults to 1)
block_pooling_function: string
Whether to perform average (AVG) or max (MAX) pooling in the block
(default: AVG)
initial_sparsity: float
initial sparsity value
target_sparsity: float
target sparsity value
    sparsity_function_begin_step: integer
      the global step at which the gradual sparsity function begins to
      take effect
sparsity_function_end_step: integer
the global step used as the end point for the gradual sparsity function
sparsity_function_exponent: float
      exponent = 1 varies the sparsity linearly between the initial and
      target values; exponent > 1 varies it more slowly towards the end
      than the beginning
    use_tpu: boolean
      indicates whether to use TPU
    We use the following sparsity schedule, where p(step) is the fraction of
    [sparsity_function_begin_step, sparsity_function_end_step] elapsed at the
    given global step, clipped to [0, 1]:
      sparsity(step) = (initial_sparsity - target_sparsity) *
                       (1 - p(step))**sparsity_function_exponent
                       + target_sparsity
Args:
None
Returns:
tf.HParams object initialized to default values
"""
return hparam.HParams(
name='model_pruning',
begin_pruning_step=0,
end_pruning_step=-1,
weight_sparsity_map=[''],
block_dims_map=[''],
threshold_decay=0.0,
pruning_frequency=10,
nbins=256,
block_height=1,
block_width=1,
block_pooling_function='AVG',
initial_sparsity=0.0,
target_sparsity=0.5,
sparsity_function_begin_step=0,
sparsity_function_end_step=100,
sparsity_function_exponent=3.0,
use_tpu=False)
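# A hedged sketch of overriding these defaults with the standard HParams
# comma-separated parse syntax (the override values are arbitrary):
#
#   hparams = get_pruning_hparams().parse(
#       'name=my_pruning,begin_pruning_step=1000,end_pruning_step=20000,'
#       'target_sparsity=0.9,pruning_frequency=100')
#   assert hparams.target_sparsity == 0.9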
class Pruning(object):
def __init__(self, spec=None, global_step=None, sparsity=None):
"""Set up the specification for model pruning.
If a spec is provided, the sparsity is set up based on the sparsity_function
in the spec. The effect of sparsity_function is overridden if the sparsity
variable is passed to the constructor. This enables setting up arbitrary
    sparsity profiles externally and passing them to the Pruning object.
Args:
      spec: Pruning spec, a tf.HParams object as returned by
        get_pruning_hparams()
global_step: A tensorflow variable that is used while setting up the
sparsity function
sparsity: A tensorflow scalar variable storing the sparsity
"""
# Pruning specification
self._spec = spec if spec else get_pruning_hparams()
# Sanity check for pruning hparams
self._validate_spec()
    # The global step tensorflow variable, used to drive the sparsity function.
# If not provided as input, the graph must already contain the global_step
# variable before calling this constructor.
self._global_step = self._setup_global_step(global_step)
# Stores the tensorflow sparsity variable.
# Built using self._setup_sparsity() or provided externally
self._sparsity = (sparsity
if sparsity is not None else self._setup_sparsity())
# List of tensorflow assignments ops for new masks and thresholds
self._assign_ops = []
# Tensorflow variable keeping track of the last global step when the masks
# were updated
self._last_update_step = self._setup_last_update_step()
# Block dimensions
self._block_dims = [self._spec.block_height, self._spec.block_width]
# Block pooling function
self._block_pooling_function = self._spec.block_pooling_function
# Mapping of layer/weight names and block dims
self._block_dims_map = self._get_block_dims_map()
# Mapping of weight names and target sparsity
self._weight_sparsity_map = self._get_weight_sparsity_map()
def _validate_spec(self):
spec = self._spec
if spec.begin_pruning_step < 0:
raise ValueError('Illegal value for begin_pruning_step')
if spec.begin_pruning_step >= spec.end_pruning_step:
if spec.end_pruning_step != -1:
        raise ValueError(
            'Pruning must begin before it can end. begin_step=%d, end_step=%d. '
            'Set end_pruning_step to -1 if pruning is required until training '
            'stops.' % (spec.begin_pruning_step, spec.end_pruning_step))
if spec.sparsity_function_begin_step < 0:
raise ValueError('Illegal value for sparsity_function_begin_step')
if spec.sparsity_function_begin_step >= spec.sparsity_function_end_step:
raise ValueError(
'Sparsity function requires begin_step < end_step')
if not 0.0 <= spec.threshold_decay < 1.0:
raise ValueError('threshold_decay must be in range [0,1)')
if not 0.0 <= spec.initial_sparsity < 1.0:
raise ValueError('initial_sparsity must be in range [0,1)')
if not 0.0 <= spec.target_sparsity < 1.0:
raise ValueError('target_sparsity must be in range [0,1)')
def _setup_global_step(self, global_step):
graph_global_step = global_step
if graph_global_step is None:
graph_global_step = training_util.get_global_step()
return math_ops.cast(graph_global_step, dtypes.int32)
def _setup_sparsity(self):
begin_step = self._spec.sparsity_function_begin_step
end_step = self._spec.sparsity_function_end_step
initial_sparsity = self._spec.initial_sparsity
target_sparsity = self._spec.target_sparsity
exponent = self._spec.sparsity_function_exponent
with ops.name_scope(self._spec.name):
p = math_ops.minimum(
1.0,
math_ops.maximum(
0.0,
math_ops.div(
math_ops.cast(self._global_step - begin_step, dtypes.float32),
end_step - begin_step)))
sparsity = math_ops.add(
math_ops.multiply(initial_sparsity - target_sparsity,
math_ops.pow(1 - p, exponent)),
target_sparsity,
name='sparsity')
return sparsity
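  # Worked example of the schedule above (hand-computed): with
  # initial_sparsity=0.0, target_sparsity=0.5, exponent=3.0, begin_step=0 and
  # end_step=100, then at global_step=50 we have p=0.5 and
  #   sparsity = (0.0 - 0.5) * (1 - 0.5)**3 + 0.5 = 0.4375
  # i.e. 87.5% of the sparsity range is covered halfway through, and the
  # schedule flattens out as p approaches 1.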
def _setup_last_update_step(self):
with variable_scope.variable_scope(
self._spec.name, use_resource=self._spec.use_tpu) as scope:
try:
last_update_step = variable_scope.get_variable(
'last_mask_update_step', [],
initializer=init_ops.zeros_initializer(),
trainable=False,
dtype=dtypes.int32)
except ValueError:
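        # get_variable raises ValueError when the variable already exists;
        # fall back to reusing it.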
scope.reuse_variables()
last_update_step = variable_scope.get_variable(
'last_mask_update_step', dtype=dtypes.int32)
return last_update_step
def _get_block_dims_map(self):
"""Returns the map of layer name: block dims."""
block_dims_map = {}
val_list = self._spec.block_dims_map
filtered_val_list = [l for l in val_list if l]
for val in filtered_val_list:
weight_name, block_dims_str = val.split(':')
block_dims_str = block_dims_str.split('x')
if len(block_dims_str) != 2:
raise ValueError('Expected 2 values for block dim for %s, got %s' %
(weight_name, block_dims_str))
block_dims = [int(block_dims_str[0]), int(block_dims_str[1])]
block_dims_map[re.compile(weight_name)] = block_dims
return block_dims_map
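  # For example (hypothetical hparam value), block_dims_map=['dense1:4x4',
  # 'dense2:1x16'] parses to {re.compile('dense1'): [4, 4],
  # re.compile('dense2'): [1, 16]}. The keys are compiled regexes, so they
  # match anywhere in a weight op name (see _get_block_dims below).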
def _get_block_dims(self, weight_name):
"""Returns the block dims for the given layer/weight name."""
block_dims_list = [
block_dims for regexp, block_dims in self._block_dims_map.items()
if regexp.search(weight_name)
]
if not block_dims_list:
return self._block_dims
if len(block_dims_list) > 1:
raise ValueError('Multiple matches in block_dims_map for weight %s' %
weight_name)
return block_dims_list[0]
def _get_weight_sparsity_map(self):
"""Returns the map of weight_name:sparsity parsed from the hparams."""
weight_sparsity_map = {}
val_list = self._spec.weight_sparsity_map
filtered_val_list = [l for l in val_list if l]
for val in filtered_val_list:
weight_name, sparsity = val.split(':')
      if float(sparsity) >= 1.0:
        raise ValueError('Weight sparsity must be less than 1.0')
weight_sparsity_map[re.compile(weight_name)] = float(sparsity)
return weight_sparsity_map
def _get_sparsity(self, weight_name):
"""Returns target sparsity for the given layer/weight name."""
target_sparsity = [
sparsity for regexp, sparsity in self._weight_sparsity_map.items()
if regexp.search(weight_name)
]
if not target_sparsity:
return self._sparsity
if len(target_sparsity) > 1:
raise ValueError(
'Multiple matches in weight_sparsity_map for weight %s' % weight_name)
# TODO(suyoggupta): This will work when initial_sparsity = 0. Generalize
# to handle other cases as well.
    return math_ops.multiply(
        self._sparsity,
        math_ops.div(target_sparsity[0], self._spec.target_sparsity))
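  # Worked example of the scaling above (hand-computed, assuming the hparams
  # use initial_sparsity=0): with target_sparsity=0.5 and a map entry
  # 'conv1:0.9', a weight matching 'conv1' gets self._sparsity * (0.9 / 0.5),
  # so it follows the same gradual schedule but tops out at 0.9 rather
  # than 0.5.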
def _update_mask(self, weights, threshold):
"""Updates the mask for a given weight tensor.
    This function first sorts the magnitudes of the weight tensor and
    estimates a threshold value such that the 'desired_sparsity' fraction of
    weights has magnitude less than the threshold.
Args:
weights: The weight tensor that needs to be masked.
threshold: The current threshold value. The function will compute a new
threshold and return the exponential moving average using the current
value of threshold
    Returns:
      new_threshold: The new value of the threshold based on weights and
        sparsity at the current global_step
      new_mask: A tensor of the same size and shape as weights, containing
        0 or 1 to indicate which of the values in weights fall below
        the threshold
Raises:
ValueError: if sparsity is not defined
"""
if self._sparsity is None:
raise ValueError('Sparsity variable undefined')
sparsity = self._get_sparsity(weights.op.name)
with ops.name_scope(weights.op.name + '_pruning_ops'):
abs_weights = math_ops.abs(weights)
k = math_ops.cast(
math_ops.round(
math_ops.cast(array_ops.size(abs_weights), dtypes.float32) *
(1 - sparsity)), dtypes.int32)
      # Sort the entire array in descending order of magnitude
      values, _ = nn_ops.top_k(
          array_ops.reshape(abs_weights, [-1]), k=array_ops.size(abs_weights))
      # Grab the (k-1)-th value, i.e. the k-th largest magnitude
      current_threshold = array_ops.gather(values, k - 1)
smoothed_threshold = math_ops.add_n([
math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay),
math_ops.multiply(threshold, self._spec.threshold_decay)
])
new_mask = math_ops.cast(
math_ops.greater_equal(abs_weights, smoothed_threshold),
dtypes.float32)
return smoothed_threshold, new_mask
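  # Hand-worked example of the threshold computation above: for
  # abs_weights = [0.1, 0.4, 0.3, 0.2] and sparsity = 0.5, we get
  # k = round(4 * (1 - 0.5)) = 2; the descending sort is [0.4, 0.3, 0.2, 0.1],
  # the (k-1)-th entry is 0.3, and with threshold_decay = 0 the new mask is
  # [0, 1, 1, 0]: exactly the top-k magnitudes survive.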
def _maybe_update_block_mask(self, weights, threshold):
"""Performs block-granular masking of the weights.
Block pruning occurs only if the block_height or block_width is > 1 and
if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise
pruning occurs.
Args:
weights: The weight tensor that needs to be masked.
threshold: The current threshold value. The function will compute a new
threshold and return the exponential moving average using the current
value of threshold
    Returns:
      new_threshold: The new value of the threshold based on weights and
        sparsity at the current global_step
      new_mask: A tensor of the same size and shape as weights, containing
        0 or 1 to indicate which of the values in weights fall below
        the threshold
Raises:
ValueError: if block pooling function is not AVG or MAX
"""
block_dims = self._get_block_dims(weights.op.name)
squeezed_weights = array_ops.squeeze(weights)
if squeezed_weights.get_shape().ndims != 2 or block_dims == [1, 1]:
return self._update_mask(weights, threshold)
if self._block_pooling_function not in ['AVG', 'MAX']:
raise ValueError('Unknown pooling function for block sparsity: %s' %
self._block_pooling_function)
with ops.name_scope(weights.op.name + '_pruning_ops'):
abs_weights = math_ops.abs(squeezed_weights)
pool_window = block_dims
pool_fn = pruning_utils.factorized_pool
squeeze_axis = None
if not self._spec.use_tpu:
pool_fn = nn_ops.pool
abs_weights = array_ops.reshape(
abs_weights,
[1, abs_weights.get_shape()[0],
abs_weights.get_shape()[1], 1])
squeeze_axis = [0, 3]
pooled_weights = pool_fn(
abs_weights,
window_shape=pool_window,
pooling_type=self._block_pooling_function,
strides=pool_window,
padding='SAME',
name=weights.op.name + '_pooled')
if pooled_weights.get_shape().ndims != 2:
pooled_weights = array_ops.squeeze(pooled_weights, axis=squeeze_axis)
smoothed_threshold, new_mask = self._update_mask(pooled_weights,
threshold)
updated_mask = pruning_utils.expand_tensor(new_mask, block_dims)
sliced_mask = array_ops.slice(
updated_mask, [0, 0],
[squeezed_weights.get_shape()[0],
squeezed_weights.get_shape()[1]])
return smoothed_threshold, array_ops.reshape(sliced_mask,
array_ops.shape(weights))
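  # Illustrative sketch of the block path above: a 4x4 weight with
  # block_dims = [2, 2] and AVG pooling is reduced to a 2x2 grid of mean
  # magnitudes, _update_mask prunes that grid, and expand_tensor tiles each
  # surviving 0/1 entry back over its 2x2 block, so each block of weights is
  # kept or dropped as a unit.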
def _get_mask_assign_ops(self):
# Make sure the assignment ops have not already been added to the list
if self._assign_ops:
raise ValueError(
'Assign op list not empty. _get_mask_assign_ops() called twice?')
masks = get_masks()
weights = get_weights()
thresholds = get_thresholds()
if len(masks) != len(thresholds):
raise ValueError(
'Number of masks %s and number of thresholds %s mismatch' %
(len(masks), len(thresholds)))
for index, mask in enumerate(masks):
threshold = thresholds[index]
weight = weights[index]
is_partitioned = isinstance(weight, variables.PartitionedVariable)
if is_partitioned:
weight = weight.as_tensor()
new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold)
self._assign_ops.append(
pruning_utils.variable_assign(threshold, new_threshold))
self._assign_ops.append(
pruning_utils.partitioned_variable_assign(mask, new_mask)
if is_partitioned else pruning_utils.variable_assign(mask, new_mask))
def mask_update_op(self):
with ops.name_scope(self._spec.name):
if not self._assign_ops:
self._get_mask_assign_ops()
with ops.control_dependencies([
state_ops.assign(
self._last_update_step,
self._global_step,
name='last_mask_update_step_assign')
]):
with ops.control_dependencies(self._assign_ops):
logging.info('Updating masks.')
return control_flow_ops.no_op('mask_update')
def conditional_mask_update_op(self):
def maybe_update_masks():
with ops.name_scope(self._spec.name):
is_step_within_pruning_range = math_ops.logical_and(
math_ops.greater_equal(self._global_step,
self._spec.begin_pruning_step),
# If end_pruning_step is negative, keep pruning forever!
math_ops.logical_or(
math_ops.less_equal(self._global_step,
self._spec.end_pruning_step),
math_ops.less(self._spec.end_pruning_step, 0)))
is_pruning_step = math_ops.less_equal(
math_ops.add(self._last_update_step, self._spec.pruning_frequency),
self._global_step)
return math_ops.logical_and(is_step_within_pruning_range,
is_pruning_step)
def mask_update_op():
return self.mask_update_op()
def no_update_op():
return control_flow_ops.no_op()
return control_flow_ops.cond(maybe_update_masks(), mask_update_op,
no_update_op)
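  # A minimal training-loop sketch (illustrative; `train_op` is hypothetical,
  # and the graph is assumed to already contain a global_step variable):
  #
  #   p = Pruning(get_pruning_hparams().parse('target_sparsity=0.9'))
  #   mask_update_op = p.conditional_mask_update_op()
  #   with tf.Session() as sess:
  #     sess.run(tf.global_variables_initializer())
  #     for _ in range(10000):
  #       sess.run(train_op)
  #       # A no-op except on steps within [begin_pruning_step,
  #       # end_pruning_step] that fall on the pruning_frequency cadence.
  #       sess.run(mask_update_op)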
def add_pruning_summaries(self):
"""Adds summaries of weight sparsities and thresholds."""
with ops.name_scope(self._spec.name + '_summaries'):
summary.scalar('sparsity', self._sparsity)
summary.scalar('last_mask_update_step', self._last_update_step)
masks = get_masks()
thresholds = get_thresholds()
for mask, threshold in zip(masks, thresholds):
summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask))
summary.scalar(threshold.op.name + '/threshold', threshold)
def print_hparams(self):
logging.info(self._spec.to_json())