
Model pruning: Training TensorFlow models to have masked connections

This document describes the API that facilitates magnitude-based pruning of neural networks' weight tensors. The API helps inject the necessary TensorFlow ops into the training graph so the model can be pruned while it is being trained.

Table of contents

  1. Model creation
  2. Pruning-related hyperparameters
       • Block Sparsity
  3. Adding pruning ops to the training graph
  4. Removing pruning ops from the trained graph
  5. Example

Model creation

The first step involves adding mask and threshold variables to the layers that need to undergo pruning. The mask variable has the same shape as the layer's weight tensor and determines which of the weights participate in the forward execution of the graph. This can be achieved by wrapping the layer's weight tensor with the apply_mask function provided in pruning.py. For example:

conv = tf.nn.conv2d(images, pruning.apply_mask(weights), stride, padding)

This creates a convolutional layer with the additional variables mask and threshold, as shown below:

[Figure: Convolutional layer with mask and threshold]
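Conceptually, the mask gates the weights element-wise before they enter the layer's computation, so a weight whose mask entry is 0 cannot influence the output. The pure-Python sketch below illustrates only that arithmetic for a fully connected layer; `masked_dense` is a hypothetical helper, not part of pruning.py, and TensorFlow is not used.

```python
def masked_dense(x, weights, mask):
    """Compute y = x @ (weights * mask) for a single input vector.

    x: list of length n_in; weights and mask: n_in x n_out nested lists,
    where mask entries are 0.0 (pruned) or 1.0 (kept).
    """
    n_in, n_out = len(weights), len(weights[0])
    y = [0.0] * n_out
    for i in range(n_in):
        for j in range(n_out):
            # A pruned weight contributes nothing, whatever its value.
            y[j] += x[i] * weights[i][j] * mask[i][j]
    return y


x = [1.0, 2.0]
weights = [[0.5, -1.0],
           [3.0, 0.25]]
mask = [[1.0, 0.0],
        [0.0, 1.0]]
print(masked_dense(x, weights, mask))  # [0.5, 0.5]
```

Changing the pruned entries (-1.0 and 3.0) to any other value leaves the output unchanged, which is what lets training continue on the masked graph.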

Alternatively, the API also provides variants of TensorFlow layers with these auxiliary variables built in (see layers). Layers currently supported:

Pruning-related hyperparameters

The pruning library allows for specification of the following hyperparameters:

| Hyperparameter | Type | Default | Description |
|:---|:---|:---|:---|
| name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common TensorFlow name_scope |
| begin_pruning_step | integer | 0 | The global step at which to begin pruning |
| end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1, implying that pruning continues until training stops |
| weight_sparsity_map | list of strings | [""] | List of weight variable name regex (or layer name regex):target sparsity pairs, e.g. [conv1:0.9,conv.*/kernel:0.8]. For layers/weights not in this list, the sparsity specified by the target_sparsity hyperparameter is used |
| block_dims_map | list of strings | [""] | List of weight variable name regex (or layer name regex):block_heightxblock_width pairs, e.g. [dense1:4x4,dense2:1x16,dense3:1x1]. For layers/weights not in this list, the block dimensions specified by the block_height and block_width hyperparameters are used |
| threshold_decay | float | 0.0 | The decay factor to use for exponential decay of the thresholds |
| pruning_frequency | integer | 10 | How often the masks are updated (in number of global steps) |
| block_height | integer | 1 | Number of rows in a block for block sparse matrices |
| block_width | integer | 1 | Number of columns in a block for block sparse matrices |
| block_pooling_function | string | AVG | The function used to pool weight values in a block: average (AVG) or max (MAX) |
| initial_sparsity | float | 0.0 | Initial sparsity value |
| target_sparsity | float | 0.5 | Target sparsity value |
| sparsity_function_begin_step | integer | 0 | The global step at which the gradual sparsity function begins to take effect |
| sparsity_function_end_step | integer | 100 | The global step used as the end point for the gradual sparsity function |
| sparsity_function_exponent | float | 3.0 | An exponent of 1 gives sparsity that varies linearly between the initial and final values; an exponent > 1 makes sparsity vary more slowly towards the end than at the beginning |
| use_tpu | bool | False | Whether training runs on TPUs |
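To make the regex:sparsity format of weight_sparsity_map concrete, here is a minimal pure-Python sketch of how a per-weight target could be looked up. `resolve_sparsity` is a hypothetical helper, and its matching rules are assumptions of this sketch: each pattern is tested with `re.search`, a name matched by more than one pattern is treated as an error, and unmatched names fall back to target_sparsity.

```python
import re


def resolve_sparsity(weight_name, weight_sparsity_map, target_sparsity=0.5):
    """Return the target sparsity for one weight variable (sketch).

    weight_sparsity_map holds "regex:sparsity" strings, as in the table.
    """
    matches = []
    for entry in weight_sparsity_map:
        if not entry:
            continue  # the default value [""] holds a single empty entry
        # Split on the last ":" so the regex itself may contain colons.
        pattern, _, value = entry.rpartition(":")
        if re.search(pattern, weight_name):
            matches.append(float(value))
    if len(matches) > 1:
        # Assumption of this sketch: overlapping patterns are ambiguous.
        raise ValueError("ambiguous weight_sparsity_map entry for " + weight_name)
    return matches[0] if matches else target_sparsity


spec = ["conv1:0.9", "conv.*/kernel:0.8"]
print(resolve_sparsity("conv1/weights", spec))  # 0.9
print(resolve_sparsity("conv2/kernel", spec))   # 0.8
print(resolve_sparsity("dense/kernel", spec))   # 0.5 (falls back)
```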

The sparsity $$s_t$$ at global step $$t$$ is given by:

$$s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n\Delta t}\right)^{3}$$

The interval between sparsity_function_begin_step and sparsity_function_end_step is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta t$$). Here $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, and $$t_0$$ is the sparsity_function_begin_step. In this equation, the sparsity_function_exponent is set to 3.
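The schedule above can be evaluated directly. In this minimal sketch, $$n\Delta t$$ is taken as sparsity_function_end_step minus sparsity_function_begin_step (which follows from the description above), the exponent is a parameter instead of being fixed at 3, and progress is clamped to [0, 1] outside that window; the clamping and the name `gradual_sparsity` are assumptions of this sketch, not the library's code.

```python
def gradual_sparsity(step, initial_sparsity=0.0, target_sparsity=0.5,
                     begin_step=0, end_step=100, exponent=3.0):
    """Sparsity s_t at global step `step` under the polynomial schedule.

    s_t = s_f + (s_i - s_f) * (1 - (t - t0) / (n * dt)) ** exponent,
    with n * dt = end_step - begin_step. Progress is clamped to [0, 1]
    (an assumption), so s_t = s_i before begin_step and s_f after end_step.
    """
    span = end_step - begin_step
    progress = min(1.0, max(0.0, (step - begin_step) / span))
    return target_sparsity + (initial_sparsity - target_sparsity) * (1.0 - progress) ** exponent


# Defaults from the hyperparameter table: s_i = 0.0, s_f = 0.5, steps 0..100.
print([gradual_sparsity(t) for t in (0, 25, 50, 75, 100)])
# [0.0, 0.2890625, 0.4375, 0.4921875, 0.5]
```

Most of the sparsity is reached early: half-way through the window, 87.5% of the final sparsity is already in place. Because the masks are only updated every pruning_frequency steps, the sparsity actually applied changes in steps rather than continuously.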

Block Sparsity

For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set the block_height and block_width hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc.). Currently, block sparsity is only supported for weight tensors that can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size [block_height, block_width], and either the average or the max absolute value in each block (set by the block_pooling_function hyperparameter) is taken as a proxy for the entire block. Convolution layer tensors are always pruned using block dimensions of [1,1].
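The block-pooling idea can be illustrated on a small rank-2 matrix. This pure-Python sketch is a simplification, not the library's implementation: `block_mask` is a hypothetical helper, it assumes the matrix dimensions are exact multiples of the block dimensions, and it picks which blocks to prune by sorting the pooled values (an assumption of this sketch).

```python
def block_mask(weights, target_sparsity, block_height=1, block_width=1,
               pooling="AVG"):
    """Return a 0/1 mask that zeroes whole blocks of a rank-2 weight matrix."""
    rows, cols = len(weights), len(weights[0])
    if rows % block_height or cols % block_width:
        raise ValueError("matrix shape must be a multiple of the block shape")
    # Pool |w| over each non-overlapping block_height x block_width block.
    pooled = {}
    for bi in range(0, rows, block_height):
        for bj in range(0, cols, block_width):
            vals = [abs(weights[i][j])
                    for i in range(bi, bi + block_height)
                    for j in range(bj, bj + block_width)]
            if pooling == "MAX":
                pooled[(bi, bj)] = max(vals)
            else:  # "AVG"
                pooled[(bi, bj)] = sum(vals) / len(vals)
    # Prune the blocks whose pooled magnitude is smallest.
    num_pruned = int(round(target_sparsity * len(pooled)))
    pruned = sorted(pooled, key=pooled.get)[:num_pruned]
    mask = [[1.0] * cols for _ in range(rows)]
    for bi, bj in pruned:
        for i in range(bi, bi + block_height):
            for j in range(bj, bj + block_width):
                mask[i][j] = 0.0
    return mask


w = [[0.1, -0.2, 0.9, 0.8],
     [0.1, 0.1, -0.7, 0.6]]
print(block_mask(w, 0.5, block_height=2, block_width=2))
# [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]
```

With 2x2 blocks the whole left block is removed as a unit, even though its entries are not all the smallest individual magnitudes; setting both block dimensions to 1 reduces this to ordinary element-wise magnitude pruning.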

Adding pruning ops to the training graph

The final step involves adding ops to the training graph that monitor the distribution of the layer's weight magnitudes and determine the layer threshold, such that masking all the weights below this threshold achieves the sparsity level desired for the current training step. This can be achieved as follows:

import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string(
    'pruning_hparams', '',
    """Comma separated list of pruning-related hyperparameters""")

with tf.Graph().as_default():

  # Create global step variable
  global_step = tf.train.get_or_create_global_step()

  # Build the model with masked weights (see "Model creation") and a
  # train_op that increments global_step; omitted here.

  # Parse pruning hyperparameters
  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

  # Create a pruning object using the pruning specification
  p = pruning.Pruning(pruning_hparams, global_step=global_step)

  # Add conditional mask update op. Executing this op will update all
  # the masks in the graph if the current global step is in the range
  # [begin_pruning_step, end_pruning_step] as specified by the pruning spec
  mask_update_op = p.conditional_mask_update_op()

  # Add summaries to keep track of the sparsity in different layers during training
  p.add_pruning_summaries()

  with tf.train.MonitoredTrainingSession(...) as mon_sess:
    while not mon_sess.should_stop():
      # Run the usual training op in the tf session
      mon_sess.run(train_op)

      # Update the masks by running the mask_update_op
      mon_sess.run(mask_update_op)

Ensure that global_step is incremented during training (for example, by passing it to the optimizer's minimize call); otherwise, pruning will not work!
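The per-layer work done by a mask update (find the threshold that yields the desired sparsity, then mask every weight below it) can be sketched in pure Python for a single layer. Assumptions of this sketch, not statements about the library: the threshold is found by sorting magnitudes, weights at or below the threshold are masked, and threshold_decay is applied as an exponential moving average of successive thresholds. `update_mask` is a hypothetical helper.

```python
def update_mask(weights, sparsity, prev_threshold=0.0, threshold_decay=0.0):
    """One mask update for a flat list of weights (simplified sketch).

    Returns (mask, threshold). The new threshold is smoothed with the
    previous one: decay * prev + (1 - decay) * raw.
    """
    magnitudes = sorted(abs(w) for w in weights)
    k = int(round(sparsity * len(magnitudes)))
    # Raw threshold: the largest magnitude that should be pruned (0.0 if none).
    raw_threshold = magnitudes[k - 1] if k > 0 else 0.0
    threshold = (threshold_decay * prev_threshold
                 + (1.0 - threshold_decay) * raw_threshold)
    # Keep only weights strictly above the threshold.
    mask = [1.0 if abs(w) > threshold else 0.0 for w in weights]
    return mask, threshold


w = [0.5, -0.1, 0.3, -0.7]
print(update_mask(w, 0.5))  # ([1.0, 0.0, 0.0, 1.0], 0.3)
```

With a nonzero threshold_decay the threshold moves towards its new value gradually across updates, so fewer weights may be masked in a given update than the current sparsity target alone would suggest.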

Removing pruning ops from the trained graph

Once the model is trained, it is necessary to remove the auxiliary variables (mask, threshold) and pruning ops added to the graph in the steps above. This can be accomplished using the strip_pruning_vars utility.

This utility generates a binary GraphDef in which the variables have been converted to constants. In particular, the threshold variables are removed from the graph and each mask variable is fused with the corresponding weight tensor to produce a masked_weight tensor. This tensor is sparse, has the same size as the weight tensor, and has the sparsity set by the target_sparsity or weight_sparsity_map hyperparameters above.

$ bazel build -c opt contrib/model_pruning:strip_pruning_vars
$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_dir=/path/to/checkpoints/ --output_node_names=graph_node1,graph_node2 --output_dir=/tmp --filename=pruning_stripped.pb

For now, it is assumed that the underlying hardware platform will provide mechanisms for compressing the sparse tensors and/or accelerating the sparse tensor computations.
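The fusion step performed when stripping the pruning variables amounts to an element-wise product of each weight tensor with its mask. The pure-Python sketch below illustrates that product on nested lists and reports the resulting fraction of zeros; `fuse_mask` is a hypothetical helper, whereas the actual utility operates on GraphDef constants.

```python
def fuse_mask(weights, mask):
    """Fold a 0/1 mask into a rank-2 weight matrix; report its sparsity."""
    masked = [[w * m for w, m in zip(w_row, m_row)]
              for w_row, m_row in zip(weights, mask)]
    total = sum(len(row) for row in masked)
    zeros = sum(1 for row in masked for v in row if v == 0.0)
    return masked, zeros / total


weights = [[0.5, -0.1],
           [0.3, -0.7]]
mask = [[1.0, 0.0],
        [0.0, 1.0]]
masked_weight, sparsity = fuse_mask(weights, mask)
print(sparsity)  # 0.5
```

The fused tensor keeps the original shape, and the zeros are explicit values, so compressing or accelerating them is left to the platform, as noted above.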

Example: Pruning and training deep CNNs on the cifar10 dataset

Please see Advanced Convolutional Neural Networks for details on the network architecture, input setup, etc. The additional changes needed to incorporate pruning are:

  • cifar10_pruning.py creates a deep CNN with the same architecture, but adds mask and threshold variables for each of the weight tensors in the convolutional and locally-connected layers.

  • cifar10_train.py adds pruning ops to the training graph as described above.

To train the pruned version of cifar10:

$ examples_dir=contrib/model_pruning/examples
$ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval}
$ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000
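The --pruning_hparams value is a comma-separated list of name=value pairs, with list-valued hyperparameters written in brackets as in the weight_sparsity_map example above. The library parses it with get_pruning_hparams().parse(...); the pure-Python sketch below only illustrates the string format. `parse_hparams` is a hypothetical helper, and its bracket handling and type coercion (values take the type of their default) are assumptions of this sketch.

```python
import re


def parse_hparams(spec, defaults):
    """Parse "key=value,key=[a,b],..." into a dict of typed values (sketch)."""
    values = dict(defaults)
    # Each item is key=[...] (commas allowed inside) or key=value (no commas).
    for key, bracketed, plain in re.findall(r"(\w+)=(?:\[([^\]]*)\]|([^,]*))", spec):
        if key not in defaults:
            raise KeyError("unknown hyperparameter: " + key)
        default = defaults[key]
        if isinstance(default, list):
            values[key] = [item for item in bracketed.split(",") if item]
        elif isinstance(default, bool):  # checked before int: bool is an int subclass
            values[key] = plain.lower() == "true"
        else:
            values[key] = type(default)(plain)
    return values


DEFAULTS = {  # a subset of the defaults from the hyperparameter table
    "name": "model_pruning",
    "begin_pruning_step": 0,
    "end_pruning_step": -1,
    "weight_sparsity_map": [""],
    "target_sparsity": 0.5,
    "sparsity_function_begin_step": 0,
    "sparsity_function_end_step": 100,
    "use_tpu": False,
}

hp = parse_hparams(
    "name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,"
    "target_sparsity=0.9,sparsity_function_begin_step=10000,"
    "sparsity_function_end_step=100000", DEFAULTS)
print(hp["target_sparsity"], hp["end_pruning_step"])  # 0.9 100000
```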

Eval:

$ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once

Removing pruning nodes from the trained graph:

$ bazel build -c opt contrib/model_pruning:strip_pruning_vars
$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_dir=/tmp/cifar10_train --output_node_names=softmax_linear/softmax_linear_2 --filename=cifar_pruned.pb

The generated GraphDef (cifar_pruned.pb) may be visualized using the import_pb_to_tensorboard utility.

References

Michael Zhu and Suyog Gupta, “To prune, or not to prune: exploring the efficacy of pruning for model compression”, 2017 NIPS Workshop on Machine Learning of Phones and other Consumer Devices (https://arxiv.org/pdf/1710.01878.pdf)