| import copy |
| import itertools |
| import logging |
| import operator |
| import random |
| import weakref |
| |
| import torch |
| import torch.nn as nn |
| from torch import _prims |
| from torch.fx.experimental.optimization import ( |
| matches_module_pattern, |
| replace_node_module, |
| ) |
| from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode |
| from torch.fx.passes.shape_prop import ShapeProp |
| from torch.nn import functional as F |
| from torch.nn.modules.utils import _pair |
| from torch.nn.utils.fusion import fuse_conv_bn_eval |
| from torch.overrides import TorchFunctionMode |
| |
| from . import config |
| |
| log = logging.getLogger(__name__) |
| |
| |
| class AutogradMonkeypatch(TorchFunctionMode): |
| def __torch_function__(self, func, types, args=(), kwargs=None): |
| if not kwargs: |
| kwargs = {} |
| if func in replacements: |
| return replacements[func](*args, **kwargs) |
| return func(*args, **kwargs) |
| |
| |
| patch_functions = AutogradMonkeypatch |
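| # A minimal usage sketch: entering the mode as a context manager (which |
| # TorchFunctionMode supports) routes any call that has an entry in |
| # `replacements` (e.g. torch.nn.functional.dropout -> lowmem_dropout) to the |
| # replacement instead of the original function: |
| # |
| #     with patch_functions(): |
| #         out = torch.nn.functional.dropout(x, p=0.1)  # uses lowmem_dropout |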
| |
| |
| def replace_fx(gm: torch.fx.GraphModule): |
| # Sometimes patch_functions() misses things already in the graph |
| for node in reversed(list(gm.graph.nodes)): |
| if node.op == "call_function" and node.target in replacements: |
| with gm.graph.inserting_before(node): |
| node.replace_all_uses_with( |
| gm.graph.call_function( |
| replacements[node.target], node.args, node.kwargs |
| ) |
| ) |
| gm.graph.erase_node(node) |
| gm.recompile() |
| return gm |
| |
| |
| class UnaryAttr(object): |
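| """Describes a unary post-op as the (op_name, scalars, algorithm) triple |
| expected by the mkldnn pointwise ops. For example, |
| UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]) applied to an |
| nn.LeakyReLU(0.2) instance returns ("leaky_relu", [0.2], "").""" |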
| def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): |
| self.op_name = op_name |
| self.scalars_attr = scalars_attr if scalars_attr else [] |
| self.algorithm_attr = algorithm_attr if algorithm_attr else "" |
| super(UnaryAttr, self).__init__() |
| |
| def __call__(self, unary_module: nn.Module): |
| assert all(hasattr(unary_module, item) for item in self.scalars_attr) |
| scalars = [getattr(unary_module, item) for item in self.scalars_attr] |
| |
| algorithm = "" |
| if self.algorithm_attr: |
| assert hasattr(unary_module, self.algorithm_attr) |
| algorithm = getattr(unary_module, self.algorithm_attr) |
| |
| return self.op_name, scalars, algorithm |
| |
| |
| class ConvUnary2d(nn.Conv2d): |
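| """Conv2d fused with a unary post-op (e.g. ReLU); forward dispatches to |
| torch.ops.mkldnn._convolution_pointwise instead of the aten convolution.""" |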
| def __init__( |
| self, |
| conv: nn.Module, |
| unary: nn.Module, |
| ): |
| super(ConvUnary2d, self).__init__( |
| conv.in_channels, |
| conv.out_channels, |
| conv.kernel_size, |
| conv.stride, |
| conv.padding, |
| conv.dilation, |
| conv.groups, |
| conv.bias is not None, |
| conv.padding_mode, |
| conv.weight.device, |
| conv.weight.dtype, |
| ) |
| self._update_module_params(conv, unary) |
| |
| def _update_module_params(self, conv, unary): |
| self.__dict__ = copy.deepcopy(conv.__dict__) |
| self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__]( |
| unary |
| ) |
| |
| def _conv_forward(self, input, weight, bias): |
| if self.padding_mode != "zeros": |
| return torch.ops.mkldnn._convolution_pointwise( |
| F.pad( |
| input, self._reversed_padding_repeated_twice, mode=self.padding_mode |
| ), |
| weight, |
| bias, |
| _pair(0), |
| self.stride, |
| self.dilation, |
| self.groups, |
| self.attr, |
| self.scalars, |
| self.algorithm, |
| ) |
| return torch.ops.mkldnn._convolution_pointwise( |
| input, |
| weight, |
| bias, |
| self.padding, |
| self.stride, |
| self.dilation, |
| self.groups, |
| self.attr, |
| self.scalars, |
| self.algorithm, |
| ) |
| |
| def forward(self, input): |
| return self._conv_forward(input, self.weight, self.bias) |
| |
| |
| class ConvBinary2d(nn.Conv2d): |
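| """Conv2d fused with a binary post-op (add/sub) and optionally a following |
| unary post-op; forward takes (input, other) and dispatches to |
| torch.ops.mkldnn._convolution_pointwise.""" |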
| def __init__( |
| self, |
| conv: nn.Module, |
| binary_op_name: str, |
| ): |
| super(ConvBinary2d, self).__init__( |
| conv.in_channels, |
| conv.out_channels, |
| conv.kernel_size, |
| conv.stride, |
| conv.padding, |
| conv.dilation, |
| conv.groups, |
| conv.bias is not None, |
| conv.padding_mode, |
| conv.weight.device, |
| conv.weight.dtype, |
| ) |
| self._update_module_params(conv, binary_op_name) |
| |
| def _update_module_params(self, conv, binary_op_name): |
| self.__dict__ = copy.deepcopy(conv.__dict__) |
| self.binary_attr = binary_op_name |
| self.binary_alpha = None |
| self.unary_attr = None |
| self.unary_scalars = [] |
| self.unary_algorithm = None |
| |
| def _update_unary_params(self, unary): |
| self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[ |
| unary.__class__ |
| ](unary) |
| |
| def _conv_forward(self, input, other, weight, bias): |
| if self.padding_mode != "zeros": |
| return torch.ops.mkldnn._convolution_pointwise( |
| F.pad( |
| input, self._reversed_padding_repeated_twice, mode=self.padding_mode |
| ), |
| other, |
| weight, |
| bias, |
| _pair(0), |
| self.stride, |
| self.dilation, |
| self.groups, |
| self.binary_attr, |
| self.binary_alpha, |
| self.unary_attr, |
| self.unary_scalars, |
| self.unary_algorithm, |
| ) |
| return torch.ops.mkldnn._convolution_pointwise( |
| input, |
| other, |
| weight, |
| bias, |
| self.padding, |
| self.stride, |
| self.dilation, |
| self.groups, |
| self.binary_attr, |
| self.binary_alpha, |
| self.unary_attr, |
| self.unary_scalars, |
| self.unary_algorithm, |
| ) |
| |
| def forward(self, input, other): |
| return self._conv_forward(input, other, self.weight, self.bias) |
| |
| |
| class ConvBinaryInplace2d(nn.Conv2d): |
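| """In-place variant of ConvBinary2d: the convolution result is accumulated |
| into `other` via torch.ops.mkldnn._convolution_pointwise_ (note the trailing |
| underscore), covering patterns such as other += conv(x).""" |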
| def __init__( |
| self, |
| conv: nn.Module, |
| binary_op_name: str, |
| ): |
| super(ConvBinaryInplace2d, self).__init__( |
| conv.in_channels, |
| conv.out_channels, |
| conv.kernel_size, |
| conv.stride, |
| conv.padding, |
| conv.dilation, |
| conv.groups, |
| conv.bias is not None, |
| conv.padding_mode, |
| conv.weight.device, |
| conv.weight.dtype, |
| ) |
| self._update_module_params(conv, binary_op_name) |
| |
| def _update_module_params(self, conv, binary_op_name): |
| self.__dict__ = copy.deepcopy(conv.__dict__) |
| self.binary_attr = binary_op_name |
| self.binary_alpha = None |
| self.unary_attr = None |
| self.unary_scalars = [] |
| self.unary_algorithm = None |
| |
| def _update_unary_params(self, unary): |
| self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[ |
| unary.__class__ |
| ](unary) |
| |
| def _conv_forward(self, input, other, weight, bias): |
| if self.padding_mode != "zeros": |
| return torch.ops.mkldnn._convolution_pointwise_( |
| F.pad( |
| input, self._reversed_padding_repeated_twice, mode=self.padding_mode |
| ), |
| other, |
| weight, |
| bias, |
| _pair(0), |
| self.stride, |
| self.dilation, |
| self.groups, |
| self.binary_attr, |
| self.binary_alpha, |
| self.unary_attr, |
| self.unary_scalars, |
| self.unary_algorithm, |
| ) |
| return torch.ops.mkldnn._convolution_pointwise_( |
| input, |
| other, |
| weight, |
| bias, |
| self.padding, |
| self.stride, |
| self.dilation, |
| self.groups, |
| self.binary_attr, |
| self.binary_alpha, |
| self.unary_attr, |
| self.unary_scalars, |
| self.unary_algorithm, |
| ) |
| |
| def forward(self, input, other): |
| return self._conv_forward(input, other, self.weight, self.bias) |
| |
| |
| class LinearUnary(nn.Linear): |
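| """Linear fused with a unary post-op; forward dispatches to |
| torch.ops.mkldnn._linear_pointwise.""" |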
| def __init__( |
| self, |
| linear: nn.Module, |
| unary: nn.Module, |
| ): |
| super(LinearUnary, self).__init__( |
| linear.in_features, |
| linear.out_features, |
| linear.bias is not None, |
| linear.weight.device, |
| linear.weight.dtype, |
| ) |
| self._update_module_params(linear, unary) |
| |
| def _update_module_params(self, linear, unary): |
| self.__dict__ = copy.deepcopy(linear.__dict__) |
| self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__]( |
| unary |
| ) |
| |
| def forward(self, input): |
| y = torch.ops.mkldnn._linear_pointwise( |
| input, self.weight, self.bias, self.attr, self.scalars, self.algorithm |
| ) |
| return y |
| |
| |
| class LinearBinary(nn.Linear): |
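| """Linear fused with a binary post-op; forward takes (input, other) and |
| dispatches to the binary overload of torch.ops.mkldnn._linear_pointwise.""" |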
| def __init__(self, linear: nn.Module, binary_op_name: str): |
| super(LinearBinary, self).__init__( |
| linear.in_features, |
| linear.out_features, |
| linear.bias is not None, |
| linear.weight.device, |
| linear.weight.dtype, |
| ) |
| self._update_module_params(linear, binary_op_name) |
| |
| def _update_module_params(self, linear, binary_op_name): |
| self.__dict__ = copy.deepcopy(linear.__dict__) |
| |
| self.attr = binary_op_name |
| |
| def forward(self, input, other): |
| y = torch.ops.mkldnn._linear_pointwise( |
| input, other, self.weight, self.bias, self.attr |
| ) |
| return y |
| |
| |
| def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module): |
| assert not (conv.training), "Fusion only for eval!" |
| return ConvUnary2d( |
| conv, |
| unary, |
| ) |
| |
| |
| def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str): |
| assert not (conv.training), "Fusion only for eval!" |
| return ConvBinary2d( |
| conv, |
| binary_op_name, |
| ) |
| |
| |
| def fused_conv_binary_inplace_eval(conv: nn.Module, binary_op_name: str): |
| assert not (conv.training), "Fusion only for eval!" |
| return ConvBinaryInplace2d( |
| conv, |
| binary_op_name, |
| ) |
| |
| |
| def fused_binary_unary_eval(conv_binary: nn.Module, unary: nn.Module): |
| assert not (conv_binary.training), "Fusion only for eval!" |
| # Reuse the original conv module and just update its unary attributes. |
| conv_binary._update_unary_params(unary) |
| return conv_binary |
| |
| |
| def is_bfloat16_module(m): |
| weight_is_bf16 = m.weight.dtype == torch.bfloat16 |
| bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16 |
| return weight_is_bf16 and bias_is_bf16 |
| |
| |
| def fused_linear_unary_eval(linear: nn.Module, unary: nn.Module): |
| assert not (linear.training), "Fusion only for eval!" |
| return LinearUnary( |
| linear, |
| unary, |
| ) |
| |
| |
| def fused_linear_binary_eval(linear: nn.Module, attr: str): |
| assert not (linear.training), "Fusion only for eval!" |
| linear_binary = LinearBinary( |
| linear, |
| attr, |
| ) |
| return linear_binary |
| |
| |
| def check_node_kind(current_node, modules, node_kind): |
| if not isinstance(current_node, torch.fx.Node): |
| return False |
| if current_node.op != "call_module": |
| return False |
| if not isinstance(current_node.target, str): |
| return False |
| if current_node.target not in modules: |
| return False |
| if type(modules[current_node.target]) is not node_kind: |
| return False |
| return True |
| |
| |
| def check_node_is_binary(node): |
| return ( |
| (node.op == "call_function" and node.target in [torch.add, torch.sub]) |
| or ( |
| node.op == "call_function" |
| and node.target |
| in [operator.add, operator.iadd, operator.sub, operator.isub] |
| ) |
| or (node.op == "call_method" and node.target in ["add", "add_", "sub", "sub_"]) |
| ) |
| |
| |
| def check_binary_op_kwargs_is_default(node): |
| # For a binary op, we require the kwargs to keep their default values: |
| # torch.sub(add)(input, other, *, alpha=1, out=None). |
| if len(node.args) > 2: |
| return False |
| if len(node.kwargs) > 0: |
| if "out" in node.kwargs and node.kwargs["out"] is not None: |
| return False |
| if "alpha" in node.kwargs and node.kwargs["alpha"] != 1.0: |
| return False |
| return True |
| |
| |
| def check_node_is_add_inplace(node): |
| return (node.op == "call_function" and node.target in [operator.iadd]) or ( |
| node.op == "call_method" and node.target in ["add_"] |
| ) |
| |
| |
| def fuse_fx(gm: torch.fx.GraphModule, example_inputs): |
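| """Applies the permute/linear fusions and, for CPU inference with mkldnn |
| available, the conv/bn, unary, and binary fusions defined in this file. |
| |
| A minimal usage sketch (MyModel is a hypothetical eval-mode CPU model): |
| |
|     model = MyModel().eval() |
|     gm = torch.fx.symbolic_trace(model) |
|     with torch.no_grad(): |
|         gm = fuse_fx(gm, example_inputs=(torch.randn(1, 3, 224, 224),)) |
| """ |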
| if config.permute_fusion: |
| # For linear/permute fusion, we need shape information from the inputs |
| # to identify and perform the proper permutation/transpose. |
| ShapeProp(gm).propagate(*example_inputs) |
| gm = linear_permute_fusion(gm) |
| gm = permute_linear_fusion(gm) |
| gm = permute_matmul_fusion(gm) |
| |
| # Make sure autograd is disabled; the fusions below are inference-only. |
| if torch.is_grad_enabled(): |
| return gm |
| if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()): |
| return gm |
| is_cpu = all( |
| example_input.device == torch.device("cpu") for example_input in example_inputs |
| ) |
| if not is_cpu: |
| return gm |
| gm = fuse_conv_bn(gm) |
| # For binary fusion, we need shape information to make sure the two binary |
| # inputs have the same tensor metadata (shape, stride, and dtype). |
| ShapeProp(gm).propagate(*example_inputs) |
| gm = fuse_unary(gm) |
| gm = fuse_binary_inplace(gm) |
| gm = fuse_binary(gm) |
| # Re-run fuse_unary to enable conv+binary+unary fusion, |
| # e.g. conv+add+relu in vision models. |
| gm = fuse_unary(gm) |
| |
| return gm |
| |
| |
| def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False): |
| """ |
| Fuses Convolution/BN layers for inference purposes. |
| """ |
| patterns = [ |
| (torch.nn.Conv1d, torch.nn.BatchNorm1d), |
| (torch.nn.Conv2d, torch.nn.BatchNorm2d), |
| (torch.nn.Conv3d, torch.nn.BatchNorm3d), |
| ] |
| modules = dict(gm.named_modules()) |
| |
| for pattern in patterns: |
| for node in gm.graph.nodes: |
| if matches_module_pattern(pattern, node, modules): |
| if len(node.args[0].users) > 1: # Output of conv is used by other nodes |
| continue |
| conv = modules[node.args[0].target] |
| bn = modules[node.target] |
| eval_mode = all(not n.training for n in [conv, bn]) |
| if not eval_mode: |
| continue |
| if not bn.track_running_stats: |
| continue |
| fused_conv = fuse_conv_bn_eval(conv, bn) |
| replace_node_module(node.args[0], modules, fused_conv) |
| node.replace_all_uses_with(node.args[0]) |
| gm.graph.erase_node(node) |
| gm.graph.lint() |
| gm.recompile() |
| return gm |
| |
| |
| def fuse_unary(gm: torch.fx.GraphModule): |
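| """Fuses an eval-mode computation module (Conv2d, Linear, ConvBinary2d, or |
| ConvBinaryInplace2d) with a directly following unary module from |
| unary_modules_map (ReLU, Sigmoid, Tanh, ...) into a single fused module.""" |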
| modules = dict(gm.named_modules()) |
| |
| for (unary_module, _), (computation_module, fuse_func,) in itertools.product( |
| unary_modules_map.items(), computation_op_unary_op_fusion_map.items() |
| ): |
| pattern = (computation_module, unary_module) |
| for node in gm.graph.nodes: |
| if matches_module_pattern(pattern, node, modules): |
| if ( |
| len(node.args[0].users) > 1 |
| ): # Output of computation_node is used by other nodes |
| continue |
| computation_node = modules[node.args[0].target] |
| unary_node = modules[node.target] |
| eval_mode = all(not n.training for n in [computation_node, unary_node]) |
| if not eval_mode: |
| continue |
| # TODO: support padding str input("valid", "same"). |
| if type(computation_node) in [nn.Conv2d] and isinstance( |
| computation_node.padding, str |
| ): |
| continue |
| # Only fuse Linear modules when the dtype is bf16. |
| if type(computation_node) in [nn.Linear] and not is_bfloat16_module( |
| computation_node |
| ): |
| continue |
| fused_module = fuse_func(computation_node, unary_node) |
| replace_node_module(node.args[0], modules, fused_module) |
| |
| node.replace_all_uses_with(node.args[0]) |
| gm.graph.erase_node(node) |
| gm.graph.lint() |
| gm.recompile() |
| return gm |
| |
| |
| def _philox_rand_like_meta(input, seed, offset): |
| return _prims.TensorMeta(input) |
| |
| |
| def _philox_rand_like(input, seed, offset): |
| # placeholder only used in tracing |
| return torch.rand_like(input) |
| |
| |
| class NormalizedLinearNode: |
| def __init__(self, node: torch.fx.Node) -> None: |
| assert node.op == "call_function" |
| assert node.target in [torch.nn.functional.linear] |
| self.node: torch.fx.Node = node |
| |
| def get_input(self) -> torch.fx.Node: |
| if len(self.node.args) > 0: |
| return self.node.args[0] |
| else: |
| return self.node.kwargs["input"] |
| |
| def get_weight(self) -> torch.fx.Node: |
| if len(self.node.args) > 1: |
| return self.node.args[1] |
| else: |
| return self.node.kwargs["weight"] |
| |
| def get_bias(self) -> torch.fx.Node: |
| if len(self.node.args) > 2: |
| return self.node.args[2] |
| else: |
| return self.node.kwargs["bias"] |
| |
| |
| class NormalizedMatmulNode: |
| def __init__(self, node: torch.fx.Node) -> None: |
| assert node.op == "call_function" |
| assert node.target in [torch.bmm, torch.matmul] |
| self.node: torch.fx.Node = node |
| |
| def get_input(self) -> torch.fx.Node: |
| if len(self.node.args) > 0: |
| return self.node.args[0] |
| else: |
| return self.node.kwargs["input"] |
| |
| def get_other(self) -> torch.fx.Node: |
| if len(self.node.args) > 1: |
| return self.node.args[1] |
| else: |
| return self.node.kwargs["other"] |
| |
| |
| def check_permute(node: torch.fx.Node): |
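| # Returns True only when the permutation swaps the last two dimensions and |
| # leaves every leading dimension in place; e.g. for a rank-3 tensor, |
| # x.permute(0, 2, 1) matches while x.permute(2, 1, 0) does not. |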
| ranks = len(node.meta["tensor_meta"].shape) |
| if len(node.args) > 3: |
| permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] |
| elif ( |
| "permutation" in node.kwargs |
| and node.kwargs["permutation"] is not None |
| and len(node.kwargs["permutation"]) > 2 |
| ): |
| permutation = [i % ranks for i in node.kwargs["permutation"]] |
| else: |
| return False |
| allowed_permutation = list(range(ranks)) |
| allowed_permutation[-1] = ranks - 2 |
| allowed_permutation[-2] = ranks - 1 |
| return permutation == allowed_permutation |
| |
| |
| def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in module.graph.nodes: |
| if ( |
| node.op == "call_method" |
| and node.target == "permute" |
| and check_permute(node) |
| ): |
| if len(node.args) > 0: |
| input_node = node.args[0] |
| else: |
| input_node = node.kwargs["input"] |
| if ( |
| input_node.op == "call_function" |
| and input_node.target == torch.nn.functional.linear |
| ): |
| normalized = NormalizedLinearNode(input_node) |
| input = normalized.get_input() |
| weight = normalized.get_weight() |
| bias = normalized.get_bias() |
| with module.graph.inserting_before(node): |
| fused_node = module.graph.call_function( |
| linear_transpose, args=(input, weight, bias) |
| ) |
| node.replace_all_uses_with(fused_node) |
| |
| module.graph.lint() |
| module.graph.eliminate_dead_code() |
| module.recompile() |
| return module |
| |
| |
| # Y1 = X * W^T + bias |
| # Y2 = Y1.permute(0, 2, 1) |
| # ----> |
| # Y2 = W * X^T + bias.unsqueeze(-1) |
| def linear_transpose( |
| input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor |
| ) -> torch.Tensor: |
| return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) |
| |
| |
| def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in module.graph.nodes: |
| if node.op == "call_function" and node.target == torch.nn.functional.linear: |
| if len(node.args) > 0: |
| input_node = node.args[0] |
| else: |
| input_node = node.kwargs["input"] |
| if ( |
| input_node.op == "call_method" |
| and input_node.target == "permute" |
| and check_permute(input_node) |
| ): |
| normalized = NormalizedLinearNode(node) |
| if len(input_node.args) > 0: |
| input = input_node.args[0] |
| else: |
| input = input_node.kwargs["input"] |
| weight = normalized.get_weight() |
| bias = normalized.get_bias() |
| with module.graph.inserting_before(node): |
| fused_node = module.graph.call_function( |
| transpose_linear, args=(input, weight, bias) |
| ) |
| node.replace_all_uses_with(fused_node) |
| |
| module.graph.lint() |
| module.graph.eliminate_dead_code() |
| module.recompile() |
| return module |
| |
| |
| def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| for node in module.graph.nodes: |
| if node.op == "call_function" and ( |
| node.target == torch.bmm or node.target == torch.matmul |
| ): |
| normalized = NormalizedMatmulNode(node) |
| A = normalized.get_input() |
| B = normalized.get_other() |
| Atrans = Btrans = False |
| if A.op == "call_method" and A.target == "permute" and check_permute(A): |
| Atrans = True |
| if len(A.args) > 0: |
| A = A.args[0] |
| else: |
| A = A.kwargs["input"] |
| |
| if B.op == "call_method" and B.target == "permute" and check_permute(B): |
| Btrans = True |
| if len(B.args) > 0: |
| B = B.args[0] |
| else: |
| B = B.kwargs["input"] |
| |
| if Atrans or Btrans: |
| with module.graph.inserting_before(node): |
| fused_node = module.graph.call_function( |
| transpose_matmul, |
| args=(A, B, Atrans, Btrans), |
| ) |
| node.replace_all_uses_with(fused_node) |
| |
| module.graph.lint() |
| module.graph.eliminate_dead_code() |
| module.recompile() |
| return module |
| |
| |
| # X1 = X.permute(0, 2, 1) |
| # Y1 = X1 * W1^T + bias1 |
| # ----> |
| # Y1 = X.transpose(-1, -2) * W1^T + bias1 |
| def transpose_linear( |
| input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor |
| ) -> torch.Tensor: |
| return torch.matmul(input.transpose(-1, -2), weight.t()) + bias |
| |
| |
| def transpose_matmul(A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool): |
| if Atrans: |
| A = A.transpose(-1, -2) |
| if Btrans: |
| B = B.transpose(-1, -2) |
| return torch.matmul(A, B) |
| |
| |
| def replace_and_fuse_for_binary( |
| computation_node, node, fuse_func, attr, modules, index_node, index_pointwise |
| ): |
| fused_module = fuse_func(computation_node, attr) |
| replace_node_module(node.args[index_node], modules, fused_module) |
| node.args[index_node].args = node.args[index_node].args + ( |
| node.args[index_pointwise], |
| ) |
| node.replace_all_uses_with(node.args[index_node]) |
| |
| |
| def binary_inputs_meta_is_same(binary_node): |
| tensor0_meta = binary_node.args[0].meta.get("tensor_meta") |
| tensor1_meta = binary_node.args[1].meta.get("tensor_meta") |
| if not tensor0_meta or not tensor1_meta: |
| return False |
| if ( |
| tensor0_meta.shape != tensor1_meta.shape |
| or tensor0_meta.stride != tensor1_meta.stride |
| or tensor0_meta.dtype != tensor1_meta.dtype |
| ): |
| return False |
| |
| return True |
| |
| |
| def fuse_binary(gm: torch.fx.GraphModule): |
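| """Fuses an eval-mode Conv2d/Linear with a following add/sub whose other |
| operand has matching tensor metadata, producing ConvBinary2d/LinearBinary; |
| supported_index_list controls which operand the computation module may be.""" |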
| modules = dict(gm.named_modules()) |
| for node in gm.graph.nodes: |
| if check_node_is_binary(node) and check_binary_op_kwargs_is_default(node): |
| for node_kind, fuse_func in computation_op_binary_op_fusion_map.items(): |
| if not isinstance(node.args[0], torch.fx.Node) or not isinstance( |
| node.args[1], torch.fx.Node |
| ): |
| continue |
| if not binary_inputs_meta_is_same(node): |
| continue |
| attr = binary_attr[node.target] |
| index_list = supported_index_list[attr] |
| for index_dict in index_list: |
| index_node = index_dict["index_computation"] |
| index_pointwise = index_dict["index_pointwise"] |
| if check_node_kind(node.args[index_node], modules, node_kind): |
| if len(node.args[index_node].users) > 1: |
| continue |
| computation_node = modules[node.args[index_node].target] |
| # TODO: support padding str input("valid", "same"). |
| if type(computation_node) in [nn.Conv2d] and isinstance( |
| computation_node.padding, str |
| ): |
| continue |
| # Only fuse Linear modules when the dtype is bf16. |
| if type(computation_node) in [ |
| nn.Linear |
| ] and not is_bfloat16_module(computation_node): |
| continue |
| replace_and_fuse_for_binary( |
| computation_node, |
| node, |
| fuse_func, |
| attr if attr != "iadd" else "add", |
| modules, |
| index_node, |
| index_pointwise, |
| ) |
| # Make sure the fused computation node comes after all of this node's |
| # input nodes in the graph order. |
| node.append(node.args[index_node]) |
| gm.graph.erase_node(node) |
| gm.graph.lint() |
| break |
| |
| gm.recompile() |
| return gm |
| |
| |
| def fuse_binary_inplace(gm: torch.fx.GraphModule): |
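| """Fuses patterns of the form other += conv(x) (or other.add_(conv(x))) into |
| ConvBinaryInplace2d, so the convolution accumulates directly into other.""" |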
| modules = dict(gm.named_modules()) |
| for node in gm.graph.nodes: |
| if check_node_is_add_inplace(node) and check_binary_op_kwargs_is_default(node): |
| for ( |
| node_kind, |
| fuse_func, |
| ) in computation_op_binary_op_fusion_inplace_map.items(): |
| if not isinstance(node.args[0], torch.fx.Node) or not isinstance( |
| node.args[1], torch.fx.Node |
| ): |
| continue |
| if not binary_inputs_meta_is_same(node): |
| continue |
| if check_node_kind(node.args[1], modules, node_kind): |
| if len(node.args[1].users) > 1: |
| continue |
| # Make sure the in-place output buffer is not the same tensor as the conv input. |
| if node.args[1].args[0] == node.args[0]: |
| continue |
| computation_node = modules[node.args[1].target] |
| # TODO: support padding str input("valid", "same"). |
| if type(computation_node) in [nn.Conv2d] and isinstance( |
| computation_node.padding, str |
| ): |
| continue |
| replace_and_fuse_for_binary( |
| computation_node, |
| node, |
| fuse_func, |
| "add", |
| modules, |
| 1, # conv module index |
| 0, # binary op index |
| ) |
| # Make sure the fused computation node comes after all of this node's |
| # input nodes in the graph order. |
| node.append(node.args[1]) |
| gm.graph.erase_node(node) |
| gm.graph.lint() |
| break |
| |
| gm.recompile() |
| return gm |
| |
| |
| philox_rand_like = _prims._make_prim( |
| schema="philox_rand_like(Tensor input, Tensor seed, int offset) -> Tensor", |
| return_type=_prims.RETURN_TYPE.NEW, |
| meta=_philox_rand_like_meta, |
| impl_aten=_philox_rand_like, |
| doc="", |
| ) |
| |
| |
| def _philox_seed_like_meta(x): |
| return _prims.TensorMeta(_philox_seed_like(x)) |
| |
| |
| def _philox_seed_like(x): |
| # We need a tensor input here so that AOT Autograd properly captures this; |
| # with just a device input, it would become a constant. |
| return torch.tensor(random.randrange(2**31), device=x.device, dtype=torch.int32) |
| |
| |
| philox_seed_like = _prims._make_prim( |
| schema="philox_seed_like(Tensor other) -> Tensor", |
| return_type=_prims.RETURN_TYPE.NEW, |
| meta=_philox_seed_like_meta, |
| impl_aten=_philox_seed_like, |
| doc="", |
| ) |
| |
| |
| def null_ref(): |
| return None |
| |
| |
| class PhiloxRandomState: |
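| """Per-trace source of (seed, offset) pairs for philox_rand_like: the seed is |
| computed once per device per trace, and next_offset advances by x.numel() on |
| each request so every random call gets a disjoint offset range.""" |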
| next_offset = 0 |
| seed = {} |
| last_tracer_ref = null_ref |
| |
| @classmethod |
| def reset(cls, tracer=None): |
| cls.next_offset = 0 |
| cls.seed = {} |
| cls.last_tracer_ref = weakref.ref(tracer) if tracer is not None else null_ref |
| |
| @classmethod |
| def get_seed_offset(cls, x): |
| modes = torch.fx.experimental.proxy_tensor.get_torch_dispatch_modes() |
| proxy_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)] |
| if proxy_modes: |
| tracer = proxy_modes[0].tracer |
| if cls.last_tracer_ref() is not tracer: |
| # tracer changed, need to reset state |
| cls.reset(tracer) |
| else: |
| # no tracer, need to reset state |
| cls.reset() |
| |
| device = x.device |
| if device not in cls.seed: |
| # Compute the seed just once per trace so that we pass fewer |
| # things from forward to backward |
| cls.seed[device] = philox_seed_like(x) |
| |
| seed = cls.seed[device] |
| offset = cls.next_offset |
| cls.next_offset += x.numel() |
| return seed, offset |
| |
| |
| class LowmemDropout(torch.autograd.Function): |
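| """Dropout that saves only the philox seed and offset instead of the boolean |
| mask; the mask is recomputed with philox_rand_like in backward, trading a |
| little recompute for lower memory use.""" |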
| @staticmethod |
| def forward(ctx, x, p): |
| ctx.p = p |
| scale = float(1.0 / (1.0 - p)) |
| seed, offset = PhiloxRandomState.get_seed_offset(x) |
| ctx.save_for_backward(seed) |
| ctx.offset = offset |
| bool_mask = philox_rand_like(x, seed, offset) > p |
| return bool_mask.to(x.dtype) * x * scale |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| p = ctx.p |
| scale = float(1.0 / (1.0 - p)) |
| (seed,) = ctx.saved_tensors |
| bool_mask = philox_rand_like(grad_output, seed, ctx.offset) > p |
| return bool_mask.to(grad_output.dtype) * grad_output * scale, None |
| |
| |
| @torch.fx.wrap |
| def lowmem_dropout(input, p=0.5, training=True, inplace=False): |
| if isinstance(input, torch.fx.Proxy): |
| # double check we don't FX trace this |
| return input.tracer.create_proxy( |
| "call_function", |
| lowmem_dropout, |
| (input, p, training), |
| {}, |
| ) |
| if not training or p == 0: |
| return input |
| result = LowmemDropout.apply(input, p) |
| if inplace: |
| input.copy_(result) |
| return result |
| |
| |
| @torch.fx.wrap |
| def rand_like(x, **kwargs): |
| if isinstance(x, torch.fx.Proxy): |
| # double check we don't FX trace this |
| return x.tracer.create_proxy("call_function", rand_like, (x,), kwargs) |
| assert kwargs.get("device", x.device) == x.device |
| seed, offset = PhiloxRandomState.get_seed_offset(x) |
| return philox_rand_like(x, seed, offset).to(kwargs.get("dtype", torch.float32)) |
| |
| |
| replacements = {torch.nn.functional.dropout: lowmem_dropout, torch.rand_like: rand_like} |
| |
| |
| computation_op_unary_op_fusion_map = { |
| nn.Conv2d: fused_conv_unary_eval, |
| nn.Linear: fused_linear_unary_eval, |
| ConvBinary2d: fused_binary_unary_eval, |
| ConvBinaryInplace2d: fused_binary_unary_eval, |
| } |
| |
| |
| unary_modules_map = { |
| nn.ReLU: UnaryAttr("relu"), |
| nn.Sigmoid: UnaryAttr("sigmoid"), |
| nn.Tanh: UnaryAttr("tanh"), |
| nn.Hardswish: UnaryAttr("hardswish"), |
| nn.LeakyReLU: UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]), |
| nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]), |
| nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"), |
| } |
| |
| |
| binary_attr = { |
| torch.add: "add", # node.op == "call_function" |
| "add": "add", # node.op == "call_method" |
| "add_": "iadd", # node.op == "call_method" |
| operator.add: "add", # node.op == "call_function" |
| operator.iadd: "iadd", # node.op == "call_function" |
| torch.sub: "sub", # node.op == "call_function" |
| "sub": "sub", # node.op == "call_method" |
| "sub_": "sub", # node.op == "call_method" |
| operator.sub: "sub", # node.op == "call_function" |
| operator.isub: "sub", # node.op == "call_function" |
| } |
| |
| |
| computation_op_binary_op_fusion_map = { |
| nn.Conv2d: fused_conv_binary_eval, |
| nn.Linear: fused_linear_binary_eval, |
| } |
| |
| |
| computation_op_binary_op_fusion_inplace_map = { |
| nn.Conv2d: fused_conv_binary_inplace_eval, |
| } |
| |
| # For add: we support both conv/linear + other and other + conv/linear. |
| # For sub/add_/sub_: we only support conv/linear - other |
| # or conv/linear +(-)= other, i.e. conv/linear must be the first operand. |
| supported_index_list = { |
| "add": [ |
| {"index_computation": 0, "index_pointwise": 1}, |
| {"index_computation": 1, "index_pointwise": 0}, |
| ], |
| "iadd": [{"index_computation": 0, "index_pointwise": 1}], |
| "sub": [{"index_computation": 0, "index_pointwise": 1}], |
| } |