import functools
import logging
import math
import numbers
import torch
import torch._decomp as decomp
import torch.ao.quantization.fx._decomposed
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._decomp.decompositions import pw_cast_for_opmath
from torch._decomp.decompositions_for_rng import extra_random_decomps
from . import config
log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed
inductor_decompositions = get_decompositions(
[
aten.arange,
aten.bitwise_and_,
aten.bitwise_or_,
aten.clamp_min_,
aten.empty_like,
aten.flip,
aten.lcm,
aten.linalg_vector_norm,
aten.sin_,
aten.sqrt_,
aten.std,
aten.std_mean,
aten._to_copy,
aten.tril_indices,
aten.triu_indices,
aten.unsafe_split,
]
)
decompositions = {**core_aten_decompositions(), **inductor_decompositions}
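# Registers a decomposition into the module-level `decompositions` table,
# warning (rather than raising) when an op already has an entry. `ops` may be
# a single overload/packet or a list of them; the registration itself is
# delegated to torch._decomp.register_decomposition.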
def register_decomposition(ops):
for op in [ops] if callable(ops) else ops:
if op in decompositions:
            log.warning("duplicate decomp: %s", op)
return decomp.register_decomposition(ops, decompositions)
@register_decomposition(aten._unsafe_view.default)
def _unsafe_view(self, size):
# this makes pattern matching easier
return self.view(size)
# TODO: for now, inductor doesn't handle asserts,
# because the condition is a SymBool that shows up as a tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor, msg):
return
@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):
if min is not None:
x = x.clamp_min(min)
if max is not None:
x = x.clamp_max(max)
return x
# TorchInductor-only decomposition. It should not be taken to core.
# See https://github.com/pytorch/torchdynamo/pull/1120
@register_decomposition([aten.floor_divide.default])
def floordiv(a, b):
return aten.div.Tensor_mode(a, b, rounding_mode="floor")
# Not really sure how to put this into the main library. PrimTorch wants
# empty_permuted to go to the prim, and typically users don't really want
# to decompose to empty_strided (but inductor is OK with it, because we are
# cool with strides and everything goes to empty_strided)
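# Worked example of the decomposition below (illustrative only): for
# size=[2, 3, 5] and physical_layout=[2, 0, 1], the physical buffer is
# empty([5, 2, 3]) and perm comes out as [1, 2, 0], so the result has
# shape [2, 3, 5] with strides [3, 1, 6]: dim 2 is outermost in memory
# and dim 1 is innermost, exactly as physical_layout requests.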
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(size, physical_layout, **kwargs):
perm = [0] * len(size)
for p, l in enumerate(physical_layout):
perm[l] = p
return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
@register_decomposition([aten.convolution_backward])
def convolution_backward(
grad_output,
input,
weight,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
):
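    # Only the bias gradient is decomposed here, and only on CUDA: it is a plain
    # sum over the batch and spatial dims, which Inductor fuses easily, while the
    # input/weight gradients are obtained by falling back to
    # aten.convolution_backward with the bias entry of output_mask turned off.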
if not output_mask[2] or grad_output.device.type != "cuda":
return NotImplemented
grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
grad_inp, grad_weight, _ = aten.convolution_backward(
grad_output,
input,
weight,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
[output_mask[0], output_mask[1], False],
)
return (grad_inp, grad_weight, grad_bias)
@register_decomposition([aten.log2])
def log2(x):
return torch.log(x) * (1.0 / math.log(2.0))
@register_decomposition([aten.round.decimals])
def round_dec(x, decimals=0):
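    # e.g. decimals=2 turns round_dec(1.234) into round(123.4) / 100 == 1.23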
ten_pow_decimals = 10.0**decimals
return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)
@register_decomposition([aten.all.default])
def all(input):
return torch.logical_not(torch.any(torch.logical_not(input)))
@register_decomposition([aten.all.dim])
def all_dim(input, dim, keepdim=False):
return torch.logical_not(torch.any(torch.logical_not(input), dim, keepdim))
# NB: this decomposition is not stride accurate, do not put it in the main
# library
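# (e.g. when src broadcasts against self, the result below is an expand_copy
# whose strides need not match self's strides, while eager copy_ keeps them)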
@register_decomposition(aten.copy)
def copy(self, src, non_blocking=False):
intermediate = src.to(self, non_blocking)
if self.size() != intermediate.size():
return aten.expand_copy.default(intermediate, self.size())
else:
return intermediate
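# baddbmm(self, b1, b2, beta, alpha) == beta * self + alpha * (b1 @ b2); the
# alpha == 1 and beta == 1 cases are special-cased below to avoid emitting
# redundant scalar multiplications, and beta == 0 drops self entirely.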
@register_decomposition([aten.baddbmm])
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
result = torch.bmm(batch1, batch2)
if not isinstance(alpha, numbers.Number) or alpha != 1:
result = result * alpha
if beta == 0:
return result
if not isinstance(beta, numbers.Number) or beta != 1:
self = self * beta
return self + result
@register_decomposition([aten.cat.default])
def cat(tensors, dim=0):
if len(tensors) == 1:
return tensors[0].clone()
return NotImplemented
@register_decomposition([aten.conj_physical])
def conj_physical(self):
assert not self.is_complex(), "TODO: implement this"
return self
@register_decomposition([aten.lift, aten.detach_])
def lift(self):
return self
@register_decomposition([aten.bernoulli.default])
def bernoulli(self, *, generator=None):
assert generator is None
return torch.rand_like(self, dtype=torch.float32) < self
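# Unlike minimum/maximum, fmin/fmax ignore NaNs: when exactly one operand is
# NaN the other operand is returned. The torch.where conditions below encode
# this, since a NaN in self only survives when other is NaN as well.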
@register_decomposition([aten.fmin, prims.fmin])
def fmin(self, other):
return torch.where(torch.isnan(other) | (other > self), self, other)
@register_decomposition([aten.fmax, prims.fmax])
def fmax(self, other):
return torch.where(torch.isnan(other) | (other < self), self, other)
@register_decomposition([aten.narrow_copy])
def narrow_copy(self, dim, start, length):
return torch.narrow(self, dim, start, length).clone()
@register_decomposition([aten.expand_copy])
def expand_copy(self, size, *, implicit=False):
return aten.expand(self, size, implicit=implicit).clone()
@register_decomposition([aten.view_copy.default])
def view_copy_default(self, size):
return aten.view(self, size).clone()
@register_decomposition([aten.view_copy.dtype])
def view_copy_dtype(self, dtype):
return self.to(dtype).clone()
@register_decomposition(aten.rand_like)
def rand_like(self, *, dtype=None, device=None, **kwargs):
return torch.rand(
[*self.size()],
dtype=dtype or self.dtype,
device=device or self.device,
**kwargs,
)
@register_decomposition(aten.randn_like)
def randn_like(self, *, dtype=None, device=None, **kwargs):
return torch.randn(
[*self.size()],
dtype=dtype or self.dtype,
device=device or self.device,
**kwargs,
)
@register_decomposition(aten.full_like)
def full_like(
self,
fill_value,
*,
dtype=None,
layout=None,
device=None,
pin_memory=False,
requires_grad=False,
memory_format=torch.preserve_format,
):
return torch.full(
[*self.size()],
fill_value,
dtype=dtype or self.dtype,
layout=layout or self.layout,
device=device or self.device,
requires_grad=requires_grad or self.requires_grad,
)
@register_decomposition(aten.randint_like.default)
def randint_like(self, high, *, dtype=None, device=None, **kwargs):
return aten.randint.low(
0,
high,
[*self.size()],
dtype=dtype or self.dtype,
device=device or self.device,
**kwargs,
)
@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(self, low, high, *, dtype=None, device=None, **kwargs):
return aten.randint.low(
low,
high,
[*self.size()],
dtype=dtype or self.dtype,
device=device or self.device,
**kwargs,
)
@register_decomposition(aten.randint.default)
def randint(high, size, **kwargs):
return aten.randint.low(0, high, size, **kwargs)
# The difference between quantize_per_tensor.default and quantize_per_tensor.tensor is
# whether scale and zero_point are passed as Python scalars or as scalar tensors.
@register_decomposition(quantized_decomposed.quantize_per_tensor.default)
def quantize_per_tensor_default_decomp_impl(
input: torch.Tensor,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
inv_scale = 1.0 / scale
return torch.clamp(
torch.round(input * inv_scale) + zero_point, quant_min, quant_max
).to(dtype)
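# Worked example of the affine scheme above (values illustrative only): with
# scale=0.5, zero_point=10, quant_min=0, quant_max=255 and dtype=torch.uint8,
# an input of 3.2 quantizes to clamp(round(3.2 / 0.5) + 10, 0, 255) == 16.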
# The difference between dequantize_per_tensor.default and dequantize_per_tensor.tensor is
# whether scale and zero_point are passed as Python scalars or as scalar tensors.
@register_decomposition(quantized_decomposed.dequantize_per_tensor.default)
def dequantize_per_tensor_default_decomp_impl(
input: torch.Tensor,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
return (input.to(torch.float32) - zero_point) * scale
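# Continuing the example above: dequantizing the 16 gives (16 - 10) * 0.5 == 3.0,
# recovering the original 3.2 up to a quantization error of at most scale / 2.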
@register_decomposition(quantized_decomposed.quantize_per_tensor.tensor)
def quantize_per_tensor_tensor_decomp_impl(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
inv_scale = 1.0 / scale
return torch.clamp(
torch.round(input * inv_scale) + zero_point, quant_min, quant_max
).to(dtype)
@register_decomposition(quantized_decomposed.dequantize_per_tensor.tensor)
def dequantize_per_tensor_tensor_decomp_impl(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
return (input.to(torch.float32) - zero_point) * scale
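# _foreach_addcmul.Scalar and _foreach_addcdiv.Scalar compute, per list element,
# self[i] + scalar * (left[i] * right[i]) and self[i] + scalar * (left[i] / right[i]);
# both are rewritten below in terms of the elementwise foreach add/mul/div ops.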
@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(self, left_tensors, right_tensors, scalar=1):
return aten._foreach_add.List(
self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
)
@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(self, left_tensors, right_tensors, scalar=1):
return aten._foreach_add.List(
self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
)
@functools.lru_cache(None)
def fast_random_decomps():
return {**decompositions, **extra_random_decomps}
def select_decomp_table():
"""decomps can change based on config"""
if config.fallback_random:
return decompositions
return fast_random_decomps()