# blob: ea991d40eea17e9e59d14c25f60c3184c550e007 [file] [log] [blame]
import torch
from torch.fx import (
Graph,
GraphModule,
Node,
)
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
import torch.nn.functional as F
from torch.nn.utils.fusion import fuse_conv_bn_weights
# TODO[jerryzh168]: move this to a more general util function
from torch.ao.quantization.fx.prepare import (
_is_activation_post_process_node,
)
import copy
import operator
from typing import Any, Callable, Dict, Optional, Tuple
# Names exported by a star-import of this module; the underscore-prefixed
# helpers defined below are intentionally not listed here.
__all__ = [
    "fold_bn_weights_into_conv_node",
    "get_aten_graph_module",
    "remove_tensor_overload_for_qdq_ops",
]
def _get_tensor_constant_from_node(node, m):
    """Resolve a ``get_attr`` node to the object it names on ``m``.

    Returns ``None`` when ``node`` is ``None`` (used for absent optional
    arguments such as a missing bias). Otherwise ``node`` must have
    ``op == "get_attr"`` and the attribute of ``m`` named by ``node.target``
    is returned.
    """
    if node is not None:
        assert node.op == "get_attr"
        attr_name = node.target
        return getattr(m, attr_name)
    return None
def _get_all_arguments(orig_args, orig_kwargs, args_schema):
    """Return one value per schema argument, in schema order.

    For every entry of ``args_schema`` the value is looked up in this order:

    1. ``orig_kwargs`` under the argument's name,
    2. ``orig_args`` at the argument's schema position (only for arguments
       that are not keyword-only and only when that position was supplied),
    3. the schema's ``default_value``.
    """

    def _resolve(position, arg_schema):
        # Explicit keyword beats positional, which beats the schema default.
        if arg_schema.name in orig_kwargs:
            return orig_kwargs[arg_schema.name]
        if not arg_schema.kwarg_only and position < len(orig_args):
            return orig_args[position]
        return arg_schema.default_value

    return [
        _resolve(position, arg_schema)
        for position, arg_schema in enumerate(args_schema)
    ]
def fold_bn_weights_into_conv_node(
    conv_node: Node,
    conv_weight_node: Node,
    conv_bias_node: Optional[Node],
    bn_node: Node,
    m: GraphModule
) -> None:
    """Fold the batch norm parameters of ``bn_node`` into ``conv_node``.

    The fused weight and bias are computed with ``fuse_conv_bn_weights`` and
    stored as attributes of ``m``. ``conv_node`` is rewired to read the fused
    bias, and every use of ``getitem(bn_node, 0)`` is redirected to
    ``conv_node``. ``bn_node`` itself is not erased here; it is removed by a
    later dead code elimination pass once none of its outputs are used.

    Args:
        conv_node: convolution node feeding ``bn_node``. Its args are read
            positionally: index 2 is the bias and index 6 the transposed flag.
        conv_weight_node: ``get_attr`` node naming the conv weight on ``m``.
        conv_bias_node: ``get_attr`` node naming the conv bias on ``m``, or
            ``None`` when the conv has no bias. In that case the fused bias is
            stored under ``<weight attribute name>_bias`` and a new
            ``get_attr`` node for it is inserted before ``conv_node``.
        bn_node: batch norm node; its target must be
            ``aten._native_batch_norm_legit_no_training.default`` or
            ``aten._native_batch_norm_legit.default``.
        m: graph module that owns the nodes and the tensors they refer to.

    Raises:
        ValueError: if ``bn_node.target`` is neither of the two supported
            batch norm overloads.
    """
    # conv args: input, weight, bias, stride, padding, dilation, transposed, ...
    conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
    conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
    transpose = conv_node.args[6]

    # eval bn args: input, weight, bias, running mean, running var, momentum, eps
    # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
    bn_args_schema = bn_node.target._schema.arguments  # type: ignore[union-attr]
    # Normalize args/kwargs/defaults into a single positional list so the
    # indices below are valid however the node was constructed.
    bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
    bn_w = _get_tensor_constant_from_node(bn_args[1], m)
    bn_b = _get_tensor_constant_from_node(bn_args[2], m)
    bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
    bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
    # The training overload has an extra `training` argument before momentum,
    # which shifts the position of eps by one.
    if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
        eps_arg_index = 6
    elif bn_node.target == torch.ops.aten._native_batch_norm_legit.default:
        eps_arg_index = 7
    else:
        # Build a single message string; passing the target as a second
        # positional argument would make str(exc) render as a tuple.
        raise ValueError(f"BN node target is unexpected {bn_node.target}")
    bn_eps = bn_args[eps_arg_index]

    fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)

    # update the weight and bias for conv
    conv_args = list(conv_node.args)
    # The fused tensors are stored on the module as returned by
    # fuse_conv_bn_weights, overwriting the original conv weight in place.
    weight_attr_name = conv_weight_node.target
    assert isinstance(weight_attr_name, str)
    setattr(m, weight_attr_name, fused_weight)
    if conv_bias_node is not None:
        # Reuse the existing bias attribute; the conv already reads from it.
        bias_attr_name = conv_bias_node.target
    else:
        # The conv had no bias: create an attribute for the fused bias and
        # make the conv read it through a new get_attr node.
        bias_attr_name = weight_attr_name + "_bias"
        with m.graph.inserting_before(conv_node):
            get_bias_node = m.graph.get_attr(bias_attr_name)
        # NOTE: here we assume the bias of conv is not quantized!
        conv_args[2] = get_bias_node
    setattr(m, bias_attr_name, fused_bias)  # type: ignore[arg-type]
    conv_node.args = tuple(conv_args)

    # native_batch_norm has 3 outputs, we expect getitem calls on the output
    # and we want to replace the uses of getitem 0 with the output of conv
    #
    # Before:
    # conv -> bn - (first output) -> users1
    #          \ - (second output) -> users2
    #          \ - (third output) -> users3
    # After:
    # conv -> (first output) -> users1
    #       bn -
    #          \ - (second output) -> users2
    #          \ - (third output) -> users3
    # if users2 and users3 are empty then bn will be removed through dead code elimination
    for user in bn_node.users:
        if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
            continue
        user.replace_all_uses_with(conv_node)
# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
for n in m.graph.nodes:
if n.op != "call_function" or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default:
continue
bn_node = n
n = bn_node.args[0]
if n.op != "call_function" or n.target != torch.ops.aten.convolution.default:
continue
conv_node = n
conv_weight_node = conv_node.args[1]
conv_bias_node = conv_node.args[2]
fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)
m.graph.eliminate_dead_code()
m.recompile()
# TODO: remove hack when we have better support for pattern matching
# move around the observer for addmm
def _rearrange_weight_observer_for_decomposed_linear(
    model: GraphModule,
) -> None:
    r"""
    Move the weight observer of a decomposed linear in front of the transpose.

    Linear is decomposed to `t - addmm` (w/ bias) or `t - mm` (w/o bias)

    before:
         weight - t - observer \
           input - observer - addmm/mm

    after:
         weight - observer - t \
           input - observer - addmm/mm

    The graph is modified in place, then dead code is eliminated, the graph is
    linted and the module is recompiled.
    """
    aten = torch.ops.aten
    # Position of the (observed) weight among the args of each matmul op.
    weight_arg_position = {
        aten.addmm.default: 2,
        aten.mm.default: 1,
    }
    modules_by_name = dict(model.named_modules(remove_duplicate=False))

    for matmul_node in model.graph.nodes:
        weight_position = weight_arg_position.get(matmul_node.target)
        if weight_position is None:
            continue

        observer_node = matmul_node.args[weight_position]
        if not _is_activation_post_process_node(observer_node, modules_by_name):
            continue

        old_transpose = observer_node.args[0]
        if old_transpose.target != aten.t.default:
            continue

        # Make the observer read what the transpose used to read ...
        observer_node.replace_input_with(old_transpose, old_transpose.args[0])

        # ... and re-create the transpose so that it consumes the observer
        # output instead. The old transpose is left for dead code elimination.
        with model.graph.inserting_after(observer_node):
            new_transpose_args = (observer_node,) + tuple(old_transpose.args[1:])
            new_transpose = model.graph.create_node(
                "call_function",
                aten.t.default,
                new_transpose_args,
                old_transpose.kwargs,
            )
        matmul_node.replace_input_with(observer_node, new_transpose)

    model.graph.eliminate_dead_code()
    model.graph.lint()
    model.recompile()
def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
    """Map every node name in ``model.graph`` to a ``(scope name, type)`` pair.

    The pair is derived from the last entry of the node's
    ``meta["nn_module_stack"]``: the text after the final ``.`` of the entry's
    first element, and the entry's second element. Nodes whose stack is
    missing or empty map to ``("", type(None))``.
    """
    # TODO: move this information to fx node itself

    def _innermost_scope(fx_node) -> Tuple[str, type]:
        module_stack = fx_node.meta.get("nn_module_stack", None)
        if not module_stack:
            return ("", type(None))
        # Only the most recently pushed entry of the stack is of interest.
        *_, innermost = module_stack.values()
        return (innermost[0].rpartition(".")[2], innermost[1])

    return {
        fx_node.name: _innermost_scope(fx_node)
        for fx_node in model.graph.nodes
    }
def get_aten_graph_module(
    pattern: Callable,
    example_inputs: Tuple[Any, ...],
    **kwargs,
) -> GraphModule:
    """
    Convert the pattern to an FX graph with decomposed aten ops.

    Args:
        pattern: callable to export.
        example_inputs: positional example inputs for ``pattern``. A deep copy
            is made and unpacked into the export call, so the objects passed
            by the caller are not handed to the exporter directly.
        **kwargs: forwarded unchanged to ``torch._dynamo.export``.

    Returns:
        The graph module produced by ``torch._dynamo.export`` (called with
        ``aten_graph=True`` and ``tracing_mode="real"``), after dead code
        elimination and recompilation.
    """
    # Avoid circular imports
    import torch._dynamo
    # The export call returns a pair; only its first element (the exported
    # graph module) is used here.
    aten_pattern, _ = torch._dynamo.export(
        pattern,
        *copy.deepcopy(example_inputs),
        aten_graph=True,
        tracing_mode="real",
        **kwargs,
    )
    aten_pattern.graph.eliminate_dead_code()
    aten_pattern.recompile()
    return aten_pattern
def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
""" Remove .tensor overload for quantize/dequantize ops so that we can
use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
"""
_MAP = {
torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
}
for n in match_pattern.graph.nodes:
if n.op != "call_function":
continue
if n.target in _MAP:
n.target = _MAP[n.target]
def _is_dropout_filter(
    match: "InternalMatch",  # type: ignore[name-defined]
    original_graph: Graph,
    pattern_graph: Graph,
) -> bool:
    """
    Match filter for the subgraph rewriter: accept a match only when the
    targets of the nodes in ``match.nodes_map`` cover every op of the aten
    dropout pattern. ``original_graph`` and ``pattern_graph`` are unused.
    """
    required_ops = {
        torch.ops.aten.empty_like.default,
        torch.ops.aten.bernoulli_.float,
        torch.ops.aten.div_.Scalar,
        torch.ops.aten.mul.Tensor,
    }
    matched_targets = {node.target for node in match.nodes_map.values()}
    return required_ops.issubset(matched_targets)
def _replace_dropout_for_eval(m: GraphModule):
    """
    Rewrite the aten training-mode dropout pattern in ``m`` into a noop.

    With dropout torch ops (nn.Dropout, F.dropout), model.eval() effectively
    turns dropout into a noop. An exported model does not get this for free:
    the aten dropout pattern generated for training stays in the graph. This
    function replaces that pattern with the eval-mode one so that no further
    dropout is applied during eval. ``m`` is modified in place and recompiled.

    See https://github.com/pytorch/pytorch/issues/103681.
    """

    def dropout_train(x):
        return F.dropout(x, p=0.5, training=True)

    def dropout_eval(x):
        return F.dropout(x, p=0.5, training=False)

    example_inputs = (torch.randn(1),)
    # Export the training pattern first, then the eval replacement.
    match_pattern, replacement_pattern = (
        get_aten_graph_module(pattern_fn, example_inputs)
        for pattern_fn in (dropout_train, dropout_eval)
    )

    # Note: The match pattern looks like:
    #
    #   empty_like_default = torch.ops.aten.empty_like.default(x)
    #   bernoulli__float = torch.ops.aten.bernoulli_.float(empty_like_default)
    #   div__scalar = torch.ops.aten.div_.Scalar(bernoulli__float, 0.5)
    #   mul_tensor = torch.ops.aten.mul.Tensor(x, div__scalar)
    #
    # `ignore_literals=True` is required so that any dropout probability is
    # matched, not only 0.5. On its own, though, it would also match every
    # mul op, because `div__scalar` is a literal as well, e.g.:
    #
    #   mul_tensor = torch.ops.aten.mul.Tensor(x, 0.8)
    #
    # Hence `ignore_literals=True` is combined with `_is_dropout_filter`, so
    # that only genuine dropout patterns are replaced.
    rewrite_options = {
        "match_filters": [_is_dropout_filter],
        "ignore_literals": True,
    }
    replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern,
        **rewrite_options,
    )
    m.recompile()