# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic to update a TensorFlow model graph with quantization operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
# Quantizable operation types that are supported by the quantization rewrite.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}
# Activations that are supported by the quantization rewrite.
_ACTIVATION_TYPES = {'Relu', 'Relu6'}
def Quantize(graph,
is_training,
weight_bits=8,
activation_bits=8,
ema_decay=0.999,
quant_delay=None,
vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
scope=None):
"""Updates graph with quantization operations.
Currently we quantize the following tensors:
* Conv/MatMul: Quantize the weights if it matches.
* Activation: Quantize the output if it matches.
* Bypass/Post-activation Bypass: Quantize both input and output
if it matches.
Args:
graph: Graph to modify.
is_training: Whether quantizing training graph or eval graph.
weight_bits: Number of bits to use for quantizing weights.
activation_bits: Number of bits to use for quantizing activations.
ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
quantization intervals for quantizing activations (see here about EMA:
https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
quant_delay: (Optional, default None) Int, count of global steps for which
to delay quantization. This helps weights stabilize at the start of
training.
vars_collection: (Optional) Collection where to store the variables for
quantization interval ends.
scope: The scope to be transformed. If it's not None, only the ops which
are in this scope will be transformed.
Raises:
ValueError: When quantization fails.
"""
if scope and not scope.endswith('/'):
scope += '/'
input_to_ops_map = input_to_ops.InputToOps(graph)
for layer_match in _FindLayersToQuantize(graph):
# Quantize the weights.
context = _GetContextFromOp(layer_match.layer_op)
# If `scope` is given, only quantize it if the consumer of weights
# (the layer op) is in the right scope.
_InsertQuantOp(
context,
'weights_quant',
layer_match.weight_tensor.op, [layer_match.layer_op],
is_training,
moving_avg=False,
ema_decay=ema_decay,
quant_delay=quant_delay,
narrow_range=True,
vars_collection=vars_collection,
bits=weight_bits,
consumer_scope=scope)
# Quantize the activations.
consumer_ops = input_to_ops_map.ConsumerOperations(
layer_match.activation_op)
add_context = context
if layer_match.bypass_op:
pattern_match_result = re.search(r'^(.*)/([^/]+)', context)
if pattern_match_result is not None:
add_context = pattern_match_result.group(1)
else:
add_context = ''
# If `scope` is given, only quantize the activation if its producer (the
# activation op of this layer) is in the right scope.
_InsertQuantOp(
add_context,
'act_quant',
layer_match.activation_op,
consumer_ops,
is_training,
moving_avg=True,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
bits=activation_bits,
init_min=0.0,
producer_scope=scope)
# Quantize the inputs and output to the bypass (if it exists). The input to
# the bypass is the bias add, and the output is the activation.
if layer_match.bypass_op is not None:
# If `scope` is given, only quantize it if both the producer and the
# consumer are in the right scope.
_InsertQuantOp(
context,
'conv_quant',
layer_match.bias_add_op, [layer_match.bypass_op],
is_training,
moving_avg=True,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
bits=activation_bits,
producer_scope=scope,
consumer_scope=scope)
# Make sure the op following this isn't an activation. If it is, we
# shouldn't quantize it, since the activation will be fused into the
# Add at inference time.
consumers = input_to_ops_map.ConsumerOperations(layer_match.bypass_op)
if any([consumer.type in _ACTIVATION_TYPES for consumer in consumers]):
logging.info('Skipping %s, because it is followed by an activation.',
layer_match.bypass_op.name)
else:
_InsertQuantOp(
add_context,
'add_quant',
layer_match.bypass_op,
input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
is_training,
moving_avg=True,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
bits=activation_bits,
producer_scope=scope,
consumer_scope=scope)
# Quantize bypass ops that occur after the activation.
if layer_match.post_activation_bypass_op is not None:
pattern_match_result = re.search(
r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name)
if pattern_match_result is not None:
post_activation_bypass_context = pattern_match_result.group(1)
else:
post_activation_bypass_context = ''
# If `scope` is given, only quantize it if the producer is in the right
# scope.
# Make sure the op following this isn't an activation. If it is, we
# shouldn't quantize it, since the activation will be fused into the
# Add at inference time.
consumers = input_to_ops_map.ConsumerOperations(
layer_match.post_activation_bypass_op)
if any([consumer.type in _ACTIVATION_TYPES for consumer in consumers]):
logging.info('Skipping %s, because it is followed by an activation.',
layer_match.post_activation_bypass_op.name)
else:
_InsertQuantOp(
post_activation_bypass_context,
'post_activation_bypass_quant',
layer_match.post_activation_bypass_op,
consumers,
is_training,
moving_avg=True,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
bits=activation_bits,
producer_scope=scope)
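# Usage sketch: this function is normally driven by the wrappers in
# quantize_graph.py (tf.contrib.quantize.create_training_graph /
# create_eval_graph), which fold batch norms before calling Quantize(). A
# minimal direct call, assuming `build_model` is a user-supplied network of
# Conv2D/MatMul layers (hypothetical; not part of this module), would look
# roughly like:
#
#   g = ops.Graph()
#   with g.as_default():
#     logits = build_model(images)  # hypothetical model builder
#     Quantize(g, is_training=True, weight_bits=8, activation_bits=8,
#              quant_delay=2000000)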
def _FindLayersToQuantize(graph):
"""Matches layers in graph to quantize.
The following patterns get matched. Nodes surrounded by [] will be
optionally matched:
weight|folded_weight
/
conv|fc
|
[batch_to_space_nd]
|
[post_conv_correction]
|
[biasadd|folded_bias]
|
[bypass]
|
activation
|
[post_activation_bypass]
Match replacements:
If weight|folded_weight is found, FakeQuant is added afterwards.
If bypass is found, FakeQuant is added before and after.
If activation is found, FakeQuant is added afterwards.
If post_activation_bypass is found, FakeQuant is added afterwards.
Args:
graph: Graph to perform match on.
Returns:
list of _LayerMatches.
"""
input_pattern = graph_matcher.OpTypePattern('*')
weight_var_pattern = graph_matcher.OpTypePattern('Variable|VariableV2')
weight_partition_identity_pattern = graph_matcher.OpTypePattern(
'Identity', inputs=[weight_var_pattern])
weight_partition_concat_pattern = graph_matcher.OpTypePattern(
'ConcatV2', inputs=[weight_partition_identity_pattern, '*', '*'])
weight_identity_pattern = graph_matcher.OpTypePattern(
'Identity',
inputs=[
graph_matcher.OneofPattern([
weight_partition_identity_pattern,
weight_partition_concat_pattern,
weight_var_pattern,
])
])
weight_resource_var_pattern = graph_matcher.OpTypePattern('ReadVariableOp')
folded_weight_pattern = graph_matcher.OpTypePattern('Mul')
# The weight input to the layer operation can come either from the Variable
# or from the folded weight (Mul).
layer_pattern = graph_matcher.OpTypePattern(
'|'.join(_QUANTIZABLE_TYPES),
inputs=[
input_pattern,
graph_matcher.OneofPattern([
weight_identity_pattern, weight_resource_var_pattern,
folded_weight_pattern
])
],
ordered_inputs=False)
# For atrous convolutions a BatchToSpaceND will occur after the depthwise
# convolution.
batch_to_space_pattern = graph_matcher.OpTypePattern(
'BatchToSpaceND',
inputs=[
layer_pattern,
graph_matcher.OpTypePattern('*'),
graph_matcher.OpTypePattern('*')
])
layer_output_pattern = graph_matcher.OneofPattern(
[batch_to_space_pattern, layer_pattern])
# For separable convolutions, we look for a conv followed by another conv,
# with no activation between the two.
sep_conv_pattern = graph_matcher.OpTypePattern(
'|'.join(_QUANTIZABLE_TYPES),
inputs=[
graph_matcher.OneofPattern([layer_output_pattern]),
graph_matcher.OpTypePattern('*')
],
ordered_inputs=False)
folded_bias_mul_pattern = graph_matcher.OpTypePattern(
'Mul',
inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern],
ordered_inputs=False)
post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
'Add',
inputs=[folded_bias_mul_pattern,
graph_matcher.OpTypePattern('*')],
ordered_inputs=False)
folded_bias_add_pattern = graph_matcher.OpTypePattern(
'Add',
inputs=[
post_layer_op_correction_pattern,
graph_matcher.OpTypePattern('*')
],
ordered_inputs=False)
# batch_norms with forced updates have an Identity operation at the end.
# TODO(suharshs): Find a way to easily skip extra Identity operations. The
# current issue is that doing so can often match patterns across many layers
# incorrectly.
batch_norm_identity = graph_matcher.OpTypePattern(
'Identity', inputs=[folded_bias_add_pattern])
bias_add_pattern = graph_matcher.OpTypePattern(
'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False)
# The input to the bypass add can be the bias add, the folded bias add, or
# the batch norm Identity.
bypass_pattern = graph_matcher.OpTypePattern(
'Add',
inputs=[
graph_matcher.OneofPattern(
[bias_add_pattern, folded_bias_add_pattern, batch_norm_identity]),
'*'
],
ordered_inputs=False)
# The input to the activation can come from the bias add, the folded bias
# add, the batch norm Identity, the bypass, or directly from the layer.
# TODO(suharshs): We should ideally skip Identity operations instead of
# treating them as activations.
activation_pattern = graph_matcher.OpTypePattern(
'|'.join(_ACTIVATION_TYPES) + '|Identity',
inputs=[
graph_matcher.OneofPattern([
bias_add_pattern,
folded_bias_add_pattern,
batch_norm_identity,
bypass_pattern,
layer_pattern,
])
])
post_activation_bypass_pattern = graph_matcher.OpTypePattern(
'Add', inputs=['*', activation_pattern], ordered_inputs=False)
# The order of the following matching blocks is very important. Since matches
# aren't guaranteed to be disjoint, we structure matches from largest to
# smallest to guarantee that the largest match always wins. Additionally, we
# ensure that we don't match layers multiple times.
layer_matches = []
# We use matched_layer_set to ensure that layers aren't matched multiple
# times.
matched_layer_set = set()
# First, we match layers that have a post activation bypass. We do this first
# to ensure we don't match only the first part of this layer, missing the
# post activation bypass node.
post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher(
post_activation_bypass_pattern)
for match_result in post_activation_bypass_layer_matcher.match_graph(graph):
layer_op = match_result.get_op(layer_pattern)
weight_tensor = match_result.get_tensor(weight_identity_pattern)
if weight_tensor is None:
weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
if weight_tensor is None:
weight_tensor = match_result.get_tensor(folded_weight_pattern)
activation_op = match_result.get_op(activation_pattern)
bias_add_op = match_result.get_op(bias_add_pattern)
if bias_add_op is None:
bias_add_op = match_result.get_op(folded_bias_add_pattern)
bypass_op = match_result.get_op(bypass_pattern)
post_activation_bypass_op = match_result.get_op(
post_activation_bypass_pattern)
if layer_op not in matched_layer_set:
matched_layer_set.add(layer_op)
layer_matches.append(
_LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
post_activation_bypass_op, bias_add_op))
# Now, we match the basic layer ending at an activation. We may get duplicate
# matches from above, but we don't add them to layer_matches.
layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
for match_result in layer_matcher.match_graph(graph):
layer_op = match_result.get_op(layer_pattern)
weight_tensor = match_result.get_tensor(weight_identity_pattern)
if weight_tensor is None:
weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
if weight_tensor is None:
weight_tensor = match_result.get_tensor(folded_weight_pattern)
activation_op = match_result.get_op(activation_pattern)
bias_add_op = match_result.get_op(bias_add_pattern)
if bias_add_op is None:
bias_add_op = match_result.get_op(folded_bias_add_pattern)
bypass_op = match_result.get_op(bypass_pattern)
if layer_op not in matched_layer_set:
matched_layer_set.add(layer_op)
layer_matches.append(
_LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, None,
bias_add_op))
# Match the final layer, where there may not be an activation and instead
# the output of the final BiasAdd must be quantized. So we treat the BiasAdd
# as the 'activation_op' in the _LayerMatch, to ensure that its output is
# quantized.
final_layer_matcher = graph_matcher.GraphMatcher(
graph_matcher.OneofPattern([bias_add_pattern, folded_bias_add_pattern]))
for match_result in final_layer_matcher.match_graph(graph):
layer_op = match_result.get_op(layer_pattern)
weight_tensor = match_result.get_tensor(weight_identity_pattern)
if weight_tensor is None:
weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
if weight_tensor is None:
weight_tensor = match_result.get_tensor(folded_weight_pattern)
activation_op = match_result.get_op(bias_add_pattern)
if activation_op is None:
activation_op = match_result.get_op(folded_bias_add_pattern)
if layer_op not in matched_layer_set:
matched_layer_set.add(layer_op)
layer_matches.append(
_LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))
# Look for separable convolutions here
sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern)
for match_result in sep_conv_matcher.match_graph(graph):
layer_op = match_result.get_op(layer_pattern)
weight_tensor = match_result.get_tensor(weight_identity_pattern)
activation_op = match_result.get_op(layer_pattern)
if layer_op not in matched_layer_set:
matched_layer_set.add(layer_op)
layer_matches.append(
_LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))
return layer_matches
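# Illustration of a typical match (all names are hypothetical): for a layer
# built as Conv2D -> BiasAdd -> Relu under the scope 'net/conv1', the matcher
# above would yield a _LayerMatch roughly like:
#
#   match.layer_op.name        # 'net/conv1/Conv2D'
#   match.weight_tensor.name   # 'net/conv1/weights/read:0'
#   match.bias_add_op.name     # 'net/conv1/BiasAdd'
#   match.activation_op.name   # 'net/conv1/Relu'
#   match.bypass_op            # None (no residual Add in this layer)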
class _LayerMatch(object):
"""Contains all information related to a matched Layer."""
def __init__(self, layer_op, weight_tensor, activation_op, bypass_op,
post_activation_bypass_op, bias_add_op):
self._layer_op = layer_op
self._weight_tensor = weight_tensor
self._activation_op = activation_op
self._bypass_op = bypass_op
self._post_activation_bypass_op = post_activation_bypass_op
self._bias_add_op = bias_add_op
@property
def layer_op(self):
return self._layer_op
@property
def weight_tensor(self):
return self._weight_tensor
@property
def activation_op(self):
return self._activation_op
@property
def bypass_op(self):
return self._bypass_op
@property
def post_activation_bypass_op(self):
return self._post_activation_bypass_op
@property
def bias_add_op(self):
return self._bias_add_op
def _FollowedByFakeQuant(tensor):
"""Returns True if the tensor is followed by a FakeQuant."""
fake_quant_ops = set([
'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
'FakeQuantWithMinMaxVarsPerChannel'
])
pass_through_ops = set(['Reshape', 'Identity'])
consumers = tensor.consumers()
while consumers:
c = consumers.pop()
if c.type in fake_quant_ops:
return True
elif c.type in pass_through_ops:
for output in c.outputs:
consumers.extend(output.consumers())
return False
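# Example of the traversal above (hypothetical ops): if `conv_op.outputs[0]`
# feeds Reshape -> Identity -> FakeQuantWithMinMaxVars, the Reshape and
# Identity are treated as pass-through and the check succeeds:
#
#   if _FollowedByFakeQuant(conv_op.outputs[0]):
#     pass  # already quantized; _InsertQuantOp returns early in this case
#
# The walk does not continue through any other op type, so a FakeQuant sitting
# behind e.g. an Add is not detected.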
def _InsertQuantOp(context,
name,
producer,
consumers,
is_training,
moving_avg=True,
init_min=-6.0,
init_max=6.0,
bits=8,
ema_decay=0.999,
quant_delay=None,
vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
narrow_range=False,
producer_scope=None,
consumer_scope=None):
"""Inserts a quant op between a producer op and (multiple) consumer ops.
Args:
context: Context where producer and consumer operations are nested.
name: Name for the new quantization op within the context.
producer: Producer operation of the pairs where quantization will be
inserted.
consumers: Consumer operations of the pairs.
is_training: Whether quantizing training graph or eval graph.
moving_avg: Specifies whether to use exponential moving average or just
the last value seen.
init_min: Starting minimum value for the new quantization op.
init_max: Starting maximum value for the new quantization op.
bits: Number of bits to use for quantization, must be between 2 and 8.
ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
quantization intervals for quantizing activations (see here about EMA:
https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
quant_delay: (Optional, default None) Int, count of global steps for which
to delay quantization. This helps weights stabilize at the start of
training.
vars_collection: (Optional) Collection where to store the variables for
quantization interval ends.
narrow_range: Whether to use the narrow quantization range
[1; 2^bits - 1] or wide range [0; 2^bits - 1].
producer_scope: The restriction of producer scope. If not None, the new op
will be inserted only when the producer is in this scope.
consumer_scope: The restriction of consumer scope. If not None, the new op
will be inserted only when all the consumers are in this scope.
Raises:
ValueError: When producer operation is not directly connected to the
consumer operation.
"""
if producer_scope and not producer.name.startswith(producer_scope):
logging.info(
'_InsertQuantOp ignores context="%s" name="%s" '
'because producer "%s" is not in scope "%s"',
context, name, producer.name, producer_scope)
return
if consumer_scope:
consumers_in_scope = []
for consumer in consumers:
if consumer.name.startswith(consumer_scope):
consumers_in_scope.append(consumer)
else:
logging.info(
'_InsertQuantOp context="%s" name="%s" ignores '
'consumer "%s" because it is not in scope "%s"',
context, name, consumer.name, consumer_scope)
return
consumers = consumers_in_scope
name_prefix = _AddContextToName(context, name)
# This is needed on TPU where name_scope == 'TPUReplicate/loop', and
# name_prefix starts with 'TPUReplicate/loop/'; without dropping it
# variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
# breaks things later.
name_scope = ops.get_name_scope()
if name_scope:
name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')
inputs = producer.outputs[0]
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
if _FollowedByFakeQuant(inputs):
return
if moving_avg:
quant = (
quant_ops.MovingAvgQuantize(
inputs,
init_min=init_min,
init_max=init_max,
ema_decay=ema_decay,
is_training=is_training,
num_bits=bits,
narrow_range=narrow_range,
vars_collection=vars_collection,
name_prefix=name_prefix))
else:
quant = (
quant_ops.LastValueQuantize(
inputs,
init_min=init_min,
init_max=init_max,
is_training=is_training,
num_bits=bits,
narrow_range=narrow_range,
vars_collection=vars_collection,
name_prefix=name_prefix))
if quant_delay and quant_delay > 0:
activate_quant = math_ops.greater_equal(
common.CreateOrGetQuantizationStep(),
quant_delay,
name=name_prefix + '/activate_quant')
quant = control_flow_ops.cond(
activate_quant,
lambda: quant,
lambda: inputs,
name=name_prefix + '/delayed_quant')
if consumers:
tensors_modified_count = common.RerouteTensor(
quant, inputs, can_modify=consumers)
# Some operations can have multiple output tensors going to the same
# consumer. Since consumers is a set, we need to ensure that
# tensors_modified_count is greater than or equal to the length of the set
# of consumers.
if tensors_modified_count < len(consumers):
raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
[consumer.name for consumer in consumers]))
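# A minimal call sketch mirroring the 'weights_quant' insertion performed in
# Quantize() above (op names and values are illustrative only):
#
#   _InsertQuantOp('net/conv1', 'weights_quant',
#                  weight_read_op, [conv_op],   # producer and its consumers
#                  is_training=True, moving_avg=False, narrow_range=True,
#                  bits=8, ema_decay=0.999, quant_delay=None)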
def _GetContextFromOp(op):
"""Gets the root context name from the op name."""
context_re = re.search(r'^(.*)/([^/]+)', op.name)
if context_re:
return context_re.group(1)
return ''
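# Example: _GetContextFromOp returns 'net/conv1' for an op named
# 'net/conv1/Conv2D', and '' for a top-level op name with no '/'.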
def _AddContextToName(context, name):
"""Adds the context to the name if it exists."""
if not context:
return name
return context + '/' + name
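# Example: _AddContextToName('net/conv1', 'weights_quant') returns
# 'net/conv1/weights_quant'; with an empty context it returns 'weights_quant'.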