import functools
import logging
import math
import numbers
import torch
import torch._decomp as decomp
from torch import Tensor
from torch._decomp import get_decompositions
from torch._prims_common import is_boolean_dtype, is_integer_dtype
from torch.utils._mode_utils import no_dispatch
from . import config, utils
log = logging.getLogger(__name__)
aten = torch.ops.aten
decompositions = get_decompositions(
[
aten._adaptive_avg_pool2d_backward,
aten.addcmul,
aten.avg_pool2d_backward,
aten.binary_cross_entropy_with_logits,
aten.clamp_max,
aten.clamp_min,
aten.col2im,
aten.cudnn_batch_norm,
aten.cudnn_batch_norm_backward,
aten.detach,
aten.dot,
aten.elu,
aten.elu_backward,
aten._embedding_bag,
aten.embedding_dense_backward,
aten.expand_as,
aten.eye,
aten.flip,
aten._fused_moving_avg_obs_fq_helper,
aten.gelu,
aten.gelu_backward,
aten.glu_backward,
aten.grid_sampler_2d,
aten.hardsigmoid,
aten.hardsigmoid_backward,
aten.hardswish,
aten.hardswish_backward,
aten.hardtanh,
aten.hardtanh_backward,
aten.im2col,
aten.index_add,
aten.index_add_,
aten.index_select,
aten.l1_loss,
aten.leaky_relu,
aten.leaky_relu_backward,
aten.linalg_vector_norm,
aten.logit,
aten.logit_backward,
aten._log_softmax,
aten._log_softmax_backward_data,
aten.logsumexp.default,
aten.max_pool2d_with_indices_backward,
aten.mse_loss,
aten.mse_loss_backward,
aten.mv,
aten.narrow,
aten.native_batch_norm,
aten.native_batch_norm_backward,
aten.native_dropout_backward,
aten.native_group_norm,
aten.native_group_norm_backward,
aten.native_layer_norm,
aten.native_layer_norm_backward,
aten.new_empty,
aten.new_full,
aten.new_ones,
aten.nll_loss_backward,
aten.nll_loss_forward,
aten.norm,
aten.reflection_pad2d_backward,
aten._reshape_alias,
aten.select_backward,
aten.select_scatter,
aten.sgn,
aten.sigmoid_backward,
aten.silu,
aten.silu_backward,
aten.slice_backward,
aten._softmax,
aten._softmax_backward_data,
aten.softplus,
aten.softplus_backward,
aten.stack,
aten.std_mean.correction,
aten.t,
aten.tanh_backward,
aten.threshold_backward,
aten.transpose.int,
aten.tril.default,
aten.unfold,
aten.unfold_backward,
aten.upsample_bilinear2d.vec,
aten.upsample_nearest2d_backward,
aten.bucketize,
]
)
def register_decomposition(ops):
    """Register a decomposition into this module's local `decompositions` table."""
    for op in [ops] if callable(ops) else ops:
        if op in decompositions:
            log.warning(f"duplicate decomp: {op}")
    return decomp.register_decomposition(ops, decompositions)
@register_decomposition([aten.clamp])
def clamp(x, min=None, max=None):
if min is not None:
x = torch.maximum(x, torch.tensor(min, dtype=x.dtype, device=x.device))
if max is not None:
x = torch.minimum(x, torch.tensor(max, dtype=x.dtype, device=x.device))
return x
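# tanh(x) = (e^x - e^-x) / (e^x + e^-x); multiplying numerator and denominator by e^-x
# gives the single-exp form used below: 2 / (1 + e^(-2x)) - 1.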
@register_decomposition([aten.tanh])
def tanh(x):
return 2.0 / (1.0 + torch.exp(-2.0 * x)) - 1.0
# TorchInductor-only decomposition. It should not be taken to core.
# See https://github.com/pytorch/torchdynamo/pull/1120
@register_decomposition([aten.floor_divide.default])
def floordiv(a, b):
return aten.div.Tensor_mode(a, b, rounding_mode="floor")
def get_padded_length(x):
if x % config.alignment_size == 0:
return 0
return int((x // config.alignment_size + 1) * config.alignment_size) - x
def pad_dim(x, padded_length, dim):
pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
return torch.cat([x, pad], dim=dim)
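# Illustration of the padding helpers (assuming config.alignment_size == 8 purely for
# the example): get_padded_length(10) == 6, and pad_dim(torch.randn(4, 10), 6, 1)
# returns a (4, 16) tensor whose last 6 columns are zero.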
def check_device_dtype(a: Tensor, b: Tensor):
return (
a.is_cuda
and b.is_cuda
and a.dtype == torch.float32
and b.dtype == torch.float32
)
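# aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) computes
# beta * input + alpha * (mat1 @ mat2). When config.triton.mm is not "aten", the
# decomposition below expresses that directly as an mm plus a scaled add. Otherwise,
# if shape padding is enabled and benchmarking (should_pad_bench) says it pays off,
# the operands are padded up to multiples of config.alignment_size before re-dispatching
# to aten.addmm, slicing off any extra output rows/columns afterwards.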
@register_decomposition([aten.addmm])
def addmm(input, mat1, mat2, *, beta=1, alpha=1):
if config.triton.mm != "aten":
out = torch.mm(mat1, mat2)
if not isinstance(alpha, numbers.Number) or alpha != 1:
out = out * alpha
if not isinstance(beta, numbers.Number) or beta != 1:
input = input * beta
return input + out
if (
config.shape_padding
and check_device_dtype(mat1, mat2)
and should_pad_bench(mat1, mat2, torch.ops.aten.addmm, input=input)
):
m_padded_length = get_padded_length(mat1.shape[0])
k_padded_length = get_padded_length(mat1.shape[1])
n_padded_length = get_padded_length(mat2.shape[1])
if k_padded_length != 0:
mat1 = pad_dim(mat1, k_padded_length, 1)
mat2 = pad_dim(mat2, k_padded_length, 0)
elif m_padded_length != 0:
mat1 = pad_dim(mat1, m_padded_length, 0)
elif n_padded_length != 0:
mat2 = pad_dim(mat2, n_padded_length, 1)
if input is not None and k_padded_length == 0:
if m_padded_length != 0 and input.dim() == 2:
input = pad_dim(input, m_padded_length, 0)
elif n_padded_length != 0:
if input.dim() == 2:
input = pad_dim(input, n_padded_length, 1)
elif input.dim() == 1:
input = pad_dim(input, n_padded_length, 0)
if k_padded_length != 0:
return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
elif m_padded_length != 0:
return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[
:-m_padded_length, :
]
elif n_padded_length != 0:
return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[
:, :-n_padded_length
]
return NotImplemented # go directly to lowering
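# Decide whether padding the matmul shapes is worth it: benchmark the original op and a
# padded-shape variant on random data (under no_dispatch(), which keeps these benchmark
# calls out of any active tracing/dispatch modes) and only report padding as worthwhile
# if the padded version is at least ~1.3x faster for mm/addmm or ~2x faster for bmm.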
def should_pad_bench(mat1, mat2, op, input=None):
with no_dispatch():
if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
m_padded_length = get_padded_length(mat1.shape[0])
k_padded_length = get_padded_length(mat1.shape[1])
n_padded_length = get_padded_length(mat2.shape[1])
elif op is torch.ops.aten.bmm:
m_padded_length = get_padded_length(mat1.shape[1])
k_padded_length = get_padded_length(mat1.shape[2])
n_padded_length = get_padded_length(mat2.shape[2])
else:
return False
if m_padded_length == k_padded_length == n_padded_length == 0:
return False
mat1 = torch.randn_like(mat1)
mat2 = torch.randn_like(mat2)
warmup = 5
rep = 100
if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
ori_time = utils.do_bench(
lambda: op(mat1, mat2), warmup=warmup, rep=rep, fast_flush=True
)[0]
else:
if input is not None:
input = torch.randn_like(input)
ori_time = utils.do_bench(
lambda: op(input, mat1, mat2), warmup=warmup, rep=rep, fast_flush=True
)[0]
mat1_pad = mat1.new_empty([get_padded_length(i) + i for i in mat1.shape])
mat2_pad = mat2.new_empty([get_padded_length(i) + i for i in mat2.shape])
if op is torch.ops.aten.addmm:
input_pad = None
if input is not None and input.is_cuda and input.dtype == torch.float32:
input_pad = input.new_empty(
[get_padded_length(i) + i for i in input.shape]
)
pad_time = utils.do_bench(
lambda: op(input_pad, mat1_pad, mat2_pad),
warmup=warmup,
rep=rep,
fast_flush=True,
)[0]
else:
pad_time = utils.do_bench(
lambda: op(mat1_pad, mat2_pad), warmup=warmup, rep=rep, fast_flush=True
)[0]
        # Shape padding introduces additional memory ops. Based on microbenchmarks,
        # requiring a 1.3x speedup for aten.mm/aten.addmm and a 2x speedup for aten.bmm
        # is a reasonable tradeoff between the gains from shape padding and the overhead
        # of the additional memory ops.
        # TODO: Build a learned model which would be better than this heuristic.
if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
return ori_time > pad_time * 1.3
else:
return ori_time > pad_time * 2
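# Note on the padding logic here and in the addmm/bmm decompositions: padding K with
# zeros leaves the product unchanged, since the extra zero columns of mat1 multiply the
# extra zero rows of mat2, so no slicing is needed in that case. Padding M or N adds
# extra output rows or columns, which are sliced off afterwards.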
@register_decomposition([aten.mm])
def mm_decomp(mat1, mat2):
if (
config.shape_padding
and check_device_dtype(mat1, mat2)
and should_pad_bench(mat1, mat2, torch.ops.aten.mm)
):
m_padded_length = get_padded_length(mat1.shape[0])
k_padded_length = get_padded_length(mat1.shape[1])
n_padded_length = get_padded_length(mat2.shape[1])
if k_padded_length != 0:
mat1 = pad_dim(mat1, k_padded_length, 1)
mat2 = pad_dim(mat2, k_padded_length, 0)
return torch.ops.aten.mm(mat1, mat2)
elif m_padded_length != 0:
mat1 = pad_dim(mat1, m_padded_length, 0)
return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :]
elif n_padded_length != 0:
mat2 = pad_dim(mat2, n_padded_length, 1)
return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length]
return NotImplemented # go directly to lowering
@register_decomposition([aten.bmm])
def bmm_decomp(mat1, mat2):
if (
config.shape_padding
and check_device_dtype(mat1, mat2)
and should_pad_bench(mat1, mat2, torch.ops.aten.bmm)
):
m_padded_length = get_padded_length(mat1.shape[1])
k_padded_length = get_padded_length(mat1.shape[2])
n_padded_length = get_padded_length(mat2.shape[2])
if k_padded_length != 0:
mat1 = pad_dim(mat1, k_padded_length, 2)
mat2 = pad_dim(mat2, k_padded_length, 1)
return torch.ops.aten.bmm(mat1, mat2)
elif m_padded_length != 0:
mat1 = pad_dim(mat1, m_padded_length, 1)
return torch.ops.aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous()
elif n_padded_length != 0:
mat2 = pad_dim(mat2, n_padded_length, 2)
return torch.ops.aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous()
return NotImplemented # go directly to lowering
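# Only the bias gradient is handled here: on CUDA, when output_mask requests grad_bias,
# it is computed as a sum of grad_output over every dimension except the channel dim,
# and the remaining gradients are re-dispatched to aten.convolution_backward with the
# bias entry of the mask turned off. All other cases fall through to the lowering.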
@register_decomposition([aten.convolution_backward])
def convolution_backward(
grad_output,
input,
weight,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
):
if not output_mask[2] or grad_output.device.type != "cuda":
return NotImplemented
grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
grad_inp, grad_weight, _ = aten.convolution_backward(
grad_output,
input,
weight,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
[output_mask[0], output_mask[1], False],
)
return (grad_inp, grad_weight, grad_bias)
@register_decomposition([aten.rsqrt])
def rsqrt(x):
return torch.reciprocal(torch.sqrt(x))
@register_decomposition([aten.log2])
def log2(x):
return torch.log(x) * (1.0 / math.log(2.0))
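# Example: round_dec(torch.tensor([3.14159]), decimals=2) evaluates
# round(314.159) * 0.01 == 3.14 (up to float rounding); multiplying by the reciprocal
# of 10**decimals avoids a division.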
@register_decomposition([aten.round.decimals])
def round_dec(x, decimals=0):
ten_pow_decimals = 10.0**decimals
return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)
@register_decomposition([aten.rsub.Tensor, aten.rsub.Scalar])
def rsub(a, b):
if isinstance(b, numbers.Number):
b = torch.tensor(b, dtype=a.dtype, device=a.device)
return b - a
@register_decomposition([aten.masked_fill])
def masked_fill(value, mask, other):
if isinstance(other, numbers.Number):
other = torch.tensor(other, dtype=value.dtype, device=value.device)
assert other.numel() == 1 and other.ndim == 0
if other.device != value.device and other.numel() == 1:
other = other.to(value.device)
if other.dtype != value.dtype:
# TODO: error out on improper complex conversions
other = other.to(value.dtype)
return torch.where(mask, other, value)
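# Example: for a float32 x, nan_to_num(torch.tensor([float("nan"), float("inf"),
# -float("inf"), 1.0])) -> [0., finfo(float32).max, finfo(float32).min, 1.];
# the x != x comparison is the standard NaN test.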
@register_decomposition([aten.nan_to_num])
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
if is_boolean_dtype(x.dtype) or is_integer_dtype(x.dtype):
return x
if nan is None:
nan = 0.0
if posinf is None:
posinf = torch.finfo(x.dtype).max
if neginf is None:
neginf = torch.finfo(x.dtype).min
nan, posinf, neginf = (
torch.tensor(v, dtype=x.dtype, device=x.device) for v in (nan, posinf, neginf)
)
x = torch.where(x != x, nan, x)
x = torch.where(x == float("inf"), posinf, x)
x = torch.where(x == float("-inf"), neginf, x)
return x
@register_decomposition([aten.all.default])
def all(input):
return torch.logical_not(torch.any(torch.logical_not(input)))
@register_decomposition([aten.all.dim])
def all_dim(input, dim, keepdim=False):
    return torch.logical_not(torch.any(torch.logical_not(input), dim, keepdim))
# NB: this decomposition is not stride accurate, do not put it in the main
# library
@register_decomposition(aten.copy)
def copy(self, src, non_blocking=False):
intermediate = src.to(self, non_blocking)
if self.size() != intermediate.size():
return aten.expand_copy.default(intermediate, self.size())
else:
return intermediate
@register_decomposition(aten.hardswish_)
def hardswish_(x):
return x.copy_(aten.hardswish(x))
@register_decomposition(aten.hardtanh_)
def hardtanh_(x, min_val=-1, max_val=1):
return x.copy_(aten.hardtanh(x, min_val, max_val))
@register_decomposition(aten.leaky_relu_)
def leaky_relu_(x, negative_slope=0.01):
return x.copy_(aten.leaky_relu(x, negative_slope))
@register_decomposition(aten.silu_)
def silu_(x):
return x.copy_(aten.silu(x))
@register_decomposition(aten.masked_fill_)
def masked_fill_(x, mask, value):
return x.copy_(aten.masked_fill(x, mask, value))
@register_decomposition([aten.log1p])
def log1p(x):
return torch.log(x + 1)
@register_decomposition([aten.baddbmm])
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
result = torch.bmm(batch1, batch2)
if not isinstance(alpha, numbers.Number) or alpha != 1:
result = result * alpha
if not isinstance(beta, numbers.Number) or beta != 1:
self = self * beta
return self + result
@register_decomposition([aten.conj_physical])
def conj_physical(self):
assert not self.is_complex(), "TODO: implement this"
return self
@register_decomposition([aten.lift, aten.detach_])
def lift(self):
return self
@register_decomposition([aten.fill.Scalar])
def fill_scalar(self, value):
return torch.full_like(self, value)
@register_decomposition([aten.fill.Tensor])
def fill_tensor(self, value: Tensor):
assert value.dim() == 0, "aten.fill.Tensor only supports 0-dimension value tensor"
return torch.full_like(self, value.item())
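# rand_like draws U ~ Uniform[0, 1) and P(U < p) == p, so comparing the uniform sample
# against the probability (the tensor self here, or the scalar p below) yields
# Bernoulli-distributed samples.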
@register_decomposition([aten.bernoulli.default])
def bernoulli(self, *, generator=None):
assert generator is None
return torch.rand_like(self, dtype=torch.float32) < self
@register_decomposition([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
assert generator is None
return torch.rand_like(self, dtype=torch.float32) < p
"""
Some decomps result in differences from eager related to randomness.
We put these decomps in a separate table `extra_random_decomps` to allow
turning them on and off via `config.fallback_random`.
"""
extra_random_decomps = get_decompositions([aten.native_dropout])
register_extra_random_decomp = functools.partial(
decomp.register_decomposition, registry=extra_random_decomps
)
@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)
@functools.lru_cache(None)
def fast_random_decomps():
return {**decompositions, **extra_random_decomps}
def select_decomp_table():
    """Return the decomposition table to use; it can change based on config."""
    if config.fallback_random:
        return decompositions
    return fast_random_decomps()
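# Hypothetical usage sketch (not part of this module): the selected table is the kind
# of thing handed to a tracer's decomposition_table argument, e.g.
#   from torch.fx.experimental.proxy_tensor import make_fx
#   gm = make_fx(fn, decomposition_table=select_decomp_table())(*example_inputs)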