# Owner(s): ["module: inductor"]
import contextlib
import copy
import itertools
import math
import sys
import unittest
from typing import Callable
from unittest.mock import patch
import numpy as np
import sympy
import torch
from torch._C import FileCheck
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import codecache, config, metrics
from torch._inductor.codegen.common import OptimizationContext
from torch._inductor.codegen.cpp import (
CppOverrides,
CppVecKernelChecker,
CppVecOverrides,
)
from torch._inductor.compile_fx import (
compile_fx,
compile_fx_inner,
complex_memory_overlap,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import InterpreterShim
from torch._inductor.utils import timed
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn import functional as F
from torch.testing._internal.common_utils import IS_MACOS, slowTest
from torch.utils._python_dispatch import TorchDispatchMode
try:
try:
from . import test_torchinductor
except ImportError:
import test_torchinductor
except unittest.SkipTest:
if __name__ == "__main__":
sys.exit(0)
raise
vec_dtypes = test_torchinductor.vec_dtypes
_lowp_fp_dtypes = (
torch.bfloat16,
torch.float16,
)
run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
TestCase = test_torchinductor.TestCase
aten = torch.ops.aten
check_model = test_torchinductor.check_model
class LstmModule(torch.nn.Module):
def __init__(
self,
input_size,
hidden_size,
num_layers,
bias=True,
bidirectional=False,
batch_first=False,
):
super().__init__()
self.lstm = torch.nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
bidirectional=bidirectional,
batch_first=batch_first,
)
def forward(self, x, h=None):
x, h = self.lstm(x, h)
return x, h
class CPUReproTests(TestCase):
common = check_model
def test_conv_stride_constraints(self):
for fmt in [torch.contiguous_format, torch.channels_last]:
# TorchDispatch doesn't work in our cuda invocation for some reason
m = torch.nn.Conv2d(5, 6, [3, 3])
def fn(inp, weight):
return (
F.conv2d(
inp, weight, None, m.stride, m.padding, m.dilation, m.groups
),
)
inp = torch.randn([2, 5, 16, 16])
inps = [inp, m.weight.to(memory_format=fmt)]
fn_fx = make_fx(fn)(*inps)
fn_compiled = compile_fx_inner(fn_fx, inps)
test_self = self
conv_seen = False
class RecordFunctions(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}
if func == torch.ops.aten.convolution.default:
                        # For CPU with mkldnn enabled, we always use channels_last
nonlocal fmt
if (
torch.backends.mkldnn.enabled
and torch.backends.mkldnn.is_available()
):
fmt = torch.channels_last
test_self.assertTrue(args[0].is_contiguous(memory_format=fmt))
test_self.assertTrue(args[1].is_contiguous(memory_format=fmt))
nonlocal conv_seen
conv_seen = True
return func(*args, **kwargs)
with RecordFunctions():
out = fn_compiled(inps)
self.assertTrue(conv_seen)
@patch("torch.cuda.is_available", lambda: False)
def test_conv2d_bn_mixed_dtype(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(
3,
16,
kernel_size=3,
stride=1,
padding=1,
bias=False,
dtype=torch.bfloat16,
)
self.bn = torch.nn.BatchNorm2d(
16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
v = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)
mod = Model().eval()
with torch.no_grad():
self.common(
mod,
(v,),
)
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
@patch("torch.cuda.is_available", lambda: False)
def test_conv2d_packed(self):
options = itertools.product([[3, 56, 56]], [True, False], [0, (0,)])
for x_shape, mode_train, padding in options:
mod = torch.nn.Sequential(
torch.nn.Conv2d(3, 64, 3, 3, padding=padding)
).train(mode=mode_train)
v = torch.randn(x_shape, dtype=torch.float32)
with torch.no_grad():
self.common(
mod,
(v,),
)
@patch("torch.cuda.is_available", lambda: False)
def test_conv2d_autocast(self):
v = torch.randn(1, 3, 28, 18, dtype=torch.float32)
mod = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 3)).eval()
with torch.no_grad(), torch.cpu.amp.autocast():
self.common(
mod,
(v,),
)
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
@patch("torch.cuda.is_available", lambda: False)
def test_unsupported_conv_transpose(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv_transpose = torch.nn.ConvTranspose2d(
3, 6, 3, stride=1, padding=1, output_padding=1
)
def forward(self, input_tensor):
x = self.conv_transpose(input_tensor)
output = torch.tanh(x)
return output
input = torch.randn(1, 3, 28, 28)
m = Model().eval()
with torch.no_grad():
compiled_m = torch.compile(m)
with self.assertRaisesRegex(
RuntimeError,
"output padding must be smaller than either stride or dilation",
):
compiled_m(input)
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
@patch("torch.cuda.is_available", lambda: False)
def test_conv_used_from_multiple_places(self):
class M(torch.nn.Module):
def __init__(self, conv_in_channel, conv_out_channel) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(conv_in_channel, conv_out_channel, (3, 3))
def forward(self, x):
res = self.conv(x)
res = F.relu(res)
res = self.conv(res)
return res
with torch.no_grad():
mod = M(3, 3).eval()
x = torch.randn(1, 3, 224, 224)
self.common(
mod,
(x,),
)
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
@patch("torch.cuda.is_available", lambda: False)
def test_linear_used_from_multiple_places(self):
class M(torch.nn.Module):
def __init__(self, in_channel, out_channel) -> None:
super().__init__()
self.linear = torch.nn.Linear(in_channel, out_channel)
def forward(self, x):
res = self.linear(x)
res = F.relu(res)
res = self.linear(res)
return res
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
with torch.no_grad():
m = M(224, 224).bfloat16().eval()
m_opt = torch.compile(m)
x = torch.randn(224, 224, dtype=torch.bfloat16)
m_opt(x)
self.assertEqual(m(x), m_opt(x))
@config.patch(implicit_fallbacks=True)
def test_multihead_attention_cpu(self):
def fn(
q,
k,
v,
embed_dim,
num_heads,
qkv_weight,
qkv_bias,
proj_weight,
proj_bias,
mask,
need_weights,
):
return torch._native_multi_head_attention(
q,
k,
v,
embed_dim,
num_heads,
qkv_weight,
qkv_bias,
proj_weight,
proj_bias,
mask,
need_weights,
)
B = 1
T = 3
embed_dim = 6
num_heads = 2
q = torch.randn([B, T, embed_dim])
k = torch.randn([B, T, embed_dim])
v = torch.randn([B, T, embed_dim])
qkv_weight = torch.randn([3 * embed_dim, embed_dim])
qkv_bias = torch.randn([3 * embed_dim])
proj_weight = torch.randn([3 * embed_dim, embed_dim])
proj_bias = torch.randn([3 * embed_dim])
mask = None
need_weights = False
inps = [
q,
k,
v,
embed_dim,
num_heads,
qkv_weight,
qkv_bias,
proj_weight,
proj_bias,
mask,
need_weights,
]
self.common(fn, inps)
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
@patch("torch.cuda.is_available", lambda: False)
def test_linear_packed(self):
options = itertools.product(
[[2, 3, 10], [2, 10], [10], [2, 0]], [3, 0], [True, False]
)
for input_shape, out_dim, bias in options:
mod = torch.nn.Sequential(
torch.nn.Linear(input_shape[-1], out_dim, bias=bias)
).eval()
v = torch.randn(input_shape)
with torch.no_grad():
self.common(
mod,
(v,),
)
if torch.ops.mkldnn._is_mkldnn_bf16_supported() and len(input_shape) > 1:
mod = mod.to(torch.bfloat16)
v = v.to(torch.bfloat16)
with torch.no_grad():
self.common(
mod,
(v,),
)
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
@patch("torch.cuda.is_available", lambda: False)
def test_conv_transpose2d_packed_cpu(self):
options = itertools.product([[1, 3, 28, 28], [3, 28, 28]], [0, (0,)])
for x_shape, padding in options:
mod = torch.nn.Sequential(
torch.nn.ConvTranspose2d(3, 64, 3, 3, padding=padding)
).eval()
v = torch.randn(x_shape, dtype=torch.float32)
with torch.no_grad():
self.common(
mod,
(v,),
)
@unittest.skipIf(not torch._C._has_mkldnn, "MKLDNN is not enabled")
@patch("torch.cuda.is_available", lambda: False)
@torch._dynamo.config.patch(dynamic_shapes=True)
@torch._dynamo.config.patch(assume_static_by_default=False)
@torch._dynamo.config.patch(allow_rnn=True)
@config.patch(freezing=True)
def _test_lstm_packed(self, params_dict, change_input_sizes=False):
from torch._dynamo.utils import counters
for (
unbatched,
input_size,
hidden_size,
num_layers,
bidirectional,
bias,
empty_state,
batch_first,
batch_size,
seq_len,
) in itertools.product(*list(params_dict.values())):
dtypes = [torch.float]
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
for dtype in dtypes:
counters.clear()
num_directions = 2 if bidirectional else 1
seq_len_var = seq_len + 3
if unbatched:
v = torch.randn(seq_len, input_size)
v_var = torch.randn(seq_len_var, input_size)
h = torch.randn(num_layers * num_directions, hidden_size)
c = torch.randn(num_layers * num_directions, hidden_size)
else:
if batch_first:
v = torch.randn(batch_size, seq_len, input_size)
v_var = torch.randn(batch_size, seq_len_var, input_size)
else:
v = torch.randn(seq_len, batch_size, input_size)
v_var = torch.randn(seq_len_var, batch_size, input_size)
h = torch.randn(
num_layers * num_directions, batch_size, hidden_size
)
c = torch.randn(
num_layers * num_directions, batch_size, hidden_size
)
mod = LstmModule(
input_size,
hidden_size,
num_layers,
bias,
bidirectional,
batch_first,
).eval()
maybe_autocast = (
torch.cpu.amp.autocast()
if dtype == torch.bfloat16
else contextlib.nullcontext()
)
with torch.no_grad(), maybe_autocast:
inps = [v]
if not empty_state:
inps.append((h, c))
fn_opt = torch._dynamo.optimize("inductor")(mod)
code = run_and_get_cpp_code(fn_opt, *inps)
                    # Check that _flat_weights are not functional tensors;
                    # otherwise, deepcopy will fail during recompilation.
fn_opt_copy = copy.deepcopy(fn_opt)
_flat_weights = fn_opt_copy.lstm._flat_weights
for _flat_weight in _flat_weights:
self.assertFalse(torch._is_functional_tensor(_flat_weight))
self.assertTrue("aten.mkldnn_rnn_layer" in code)
self.assertEqual(fn_opt(*inps), mod(*inps))
self.assertEqual(
counters["inductor"]["pattern_matcher_count"],
num_layers * num_directions
                        + 2,  # number of mkldnn_rnn_layer calls + 2 view calls on the concatenated hy, cy.
)
# Change input sizes
if change_input_sizes:
inps_var = [v_var]
self.assertEqual(fn_opt(*inps_var), mod(*inps_var))
@slowTest
def test_lstm_packed(self):
params_dict = {
"unbatched": [True, False],
"input_size": [1, 2],
"hidden_size": [5, 32],
"num_layers": [1, 3],
"bidirectional": [False, True],
"bias": [False, True],
"empty_state": [False, True],
"batch_first": [True, False],
"batch_size": [1, 2],
"seq_len": [1, 3],
}
self._test_lstm_packed(params_dict)
def test_lstm_packed_change_input_sizes(self):
params_dict = {
"unbatched": [False],
"input_size": [2],
"hidden_size": [5],
"num_layers": [3],
"bidirectional": [True],
"bias": [True],
"empty_state": [False],
"batch_first": [False],
"batch_size": [2],
"seq_len": [3],
}
self._test_lstm_packed(params_dict, change_input_sizes=True)
@torch._dynamo.config.patch(dynamic_shapes=True)
@torch._dynamo.config.patch(assume_static_by_default=False)
@torch._dynamo.config.patch(allow_rnn=True)
def test_pack_padded_sequence_lstm(self):
embedding_dim = 12
hidden_dim = 10
batch_size = 24
num_layers = 1
bidirectional = True
num_direc = 2
max_lens = 96
sent = torch.randn(batch_size, max_lens, embedding_dim)
hid_0 = torch.rand(num_layers * num_direc, batch_size, hidden_dim)
hid_1 = torch.randn(num_layers * num_direc, batch_size, hidden_dim)
sent_lens = torch.Tensor(
[1, 2, 3, 4, 5, 1, 3, 2, 96, 5, 3, 1, 1, 2, 1, 2, 3, 6, 1, 2, 4, 6, 2, 1]
)
assert sent_lens.shape[0] == batch_size
assert sent_lens.max().item() == max_lens
hidden_0 = hid_0.clone().requires_grad_(False)
hidden_1 = hid_1.clone().requires_grad_(False)
embeds = torch.nn.utils.rnn.pack_padded_sequence(
sent, sent_lens, batch_first=True, enforce_sorted=False
)
mod = LstmModule(
embedding_dim,
hidden_dim,
num_layers=num_layers,
bias=True,
bidirectional=bidirectional,
batch_first=True,
).eval()
with torch.no_grad():
inps = [embeds, (hidden_0, hidden_1)]
fn_opt = torch._dynamo.optimize("inductor")(mod)
code = run_and_get_cpp_code(fn_opt, *inps)
# This case is unsupported
self.assertFalse("torch.ops.mkldnn._lstm" in code)
self.assertEqual(fn_opt(*inps), mod(*inps))
@patch("torch.cuda.is_available", lambda: False)
def test_conv_transpose2d_has_output_size_input(self):
# https://github.com/pytorch/pytorch/issues/100344.
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv_transpose = torch.nn.ConvTranspose2d(
in_channels=3, out_channels=1, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
return self.conv_transpose(x, output_size=(10, 10))
mod = M().eval()
v = torch.randn(1, 3, 10, 10, dtype=torch.float32)
with torch.no_grad():
self.common(
mod,
(v,),
)
def test_pad_with_nan_value(self):
# https://github.com/pytorch/pytorch/issues/100988.
class Model(torch.nn.Module):
def forward(self, x):
x = F.pad(x, (1, 1, 1, 1), value=float("nan"))
return x
mod = Model().eval()
v = torch.randn(1, 3, 10, 10, dtype=torch.float32)
with torch.no_grad():
self.common(
mod,
(v,),
)
def test_masked_fill_with_inf_or_nan_value(self):
def fn(value, mask):
y1 = torch.masked_fill(value, mask, float("inf"))
y2 = torch.masked_fill(value, mask, float("-inf"))
y3 = torch.masked_fill(value, mask, float("nan"))
return y1, y2, y3
value = torch.randn((2, 17))
mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool)
with torch.no_grad():
self.common(
fn,
(value, mask),
)
@config.patch(implicit_fallbacks=True)
def test_repeat_interleave(self):
def fn(y):
return torch.repeat_interleave(y, 2, output_size=8)
a = torch.tensor([[1, 2], [3, 4]])
self.common(
fn,
(a,),
)
def test_inplace_squeeze_needed(self):
mod = torch.nn.Sequential(
torch.nn.Linear(10, 10),
torch.nn.LayerNorm(10),
torch.nn.ReLU(),
).eval()
def fn(x):
return mod(x)
v = torch.randn(10)
        # TODO: OMP parallel reduction order is not deterministic.
        # Hence, the accuracy might vary up and down. For the short term,
        # we increase the tolerance and will fix it later by using
        # aten parallel.
self.common(fn, (v,), atol=5e-1, rtol=5e-1)
def test_cat_mul(self):
# https://github.com/pytorch/pytorch/issues/93365
def fn(p0, p1):
y1 = torch.cat([p0, p1], dim=0)
y2 = torch.mul(y1, y1)
return y1, y2
p0 = torch.randn(3, 4)
p1 = torch.randn(3, 4)
self.common(fn, (p0, p1))
def test_pow_cos(self):
# https://github.com/pytorch/pytorch/issues/98149
def fn(x):
t = x.pow(5)
return torch.cos(t)
x = torch.tensor([4], dtype=torch.uint8)
self.common(fn, (x,))
def test_reduce_with_masked(self):
# https://github.com/pytorch/pytorch/issues/96484
def fn(a, b):
a = torch.nn.functional.pad(a, (0, -1))
c = a + b
return c.min(0).values
a = torch.randn([2])
b = torch.randn([2])
self.common(fn, (a, b))
def test_scalar_sign_with_min(self):
# https://github.com/pytorch/pytorch/issues/101340
def fn(a):
t1 = torch.tanh(a)
t2 = torch.sign(t1)
return torch.min(t1, t2)
a = torch.randn(1, 3)
self.common(fn, (a,))
def test_index_propagation_issue_102065(self):
def fn(x):
x = torch.arange(x.numel())
return (x.unsqueeze(0) - x.unsqueeze(1)) ** 2
self.common(
fn,
(torch.randn(8),),
)
def test_ModularIndexing_range_issue_103133(self):
def fn(q, k):
einsum = torch.einsum("bcxd,bcyd->bcxy", (q, k))
constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
einsum, [0, 0, 0, 1], 0.0
)
view = torch.ops.aten.view.default(constant_pad_nd, [12, 1, 512, 513])
y = view.new_zeros((12, 2, 256, 513))
y[:, :-1, :, 256:] = view[:, :, :256, :257]
return y
self.common(
fn,
(
torch.empty_strided((12, 1, 512, 64), (64, 196608, 768, 1)),
torch.empty_strided((12, 1, 512, 64), (64, 196608, 768, 1)),
),
)
@patch("torch.cuda.is_available", lambda: False)
def test_max_reduction_lowp_fp(self):
def fn(x):
return torch.ops.aten.max(x, 1, keepdim=True)[0].float()
for dtype in _lowp_fp_dtypes:
self.common(
fn,
(torch.randn(1, 32, 4, 4).to(dtype),),
)
@patch("torch.cuda.is_available", lambda: False)
def test_vec_transpose_lowp_fp(self):
for dtype in _lowp_fp_dtypes:
def fn(x):
return x.to(memory_format=torch.channels_last).to(dtype)
self.common(
fn,
(torch.randn(2, 3, 4, 4),),
)
def test_load_inf_bf16(self):
def fn1(x):
return torch.where(x > 0, x, math.inf)
def fn2(x):
return torch.where(x > 0, x, -math.inf)
for fn in [fn1, fn2]:
self.common(
fn,
(torch.randn(1, 3, 16, 16),),
)
@patch("torch.cuda.is_available", lambda: False)
def test_fp32_load_with_to_lowp_fp(self):
# From llama model.
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.cache_k = torch.zeros(8, 4, 2, 2)
def forward(self, x, xk):
bsz, seqlen, _ = x.shape
self.cache_k = self.cache_k.to(x)
self.cache_k[:bsz, 1 : 1 + seqlen] = xk
return self.cache_k
for dtype in _lowp_fp_dtypes:
ref_model = Model().eval()
opt_model = torch.compile()(Model().eval())
x = torch.randn(4, 2, 2).to(dtype)
xk = torch.randn(4, 2, 2, 2).to(dtype)
self.assertEqual(opt_model(x, xk), ref_model(x, xk))
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_sigmoid_with_reduction(self):
def fn(x):
x = torch.ops.aten.sigmoid.default(x)
return torch.ops.aten.mean.dim(x, [-1, -2], True)
x = torch.randn((1, 8, 8, 8))
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
def test_slice_scatter_default_end_value(self):
# From HF AllenaiLongformerBase.
def fn(query, key, window_overlap):
batch_size, seq_len, num_heads, head_dim = query.size()
assert (
seq_len % (window_overlap * 2) == 0
), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
diagonal_chunked_attention_scores = key
diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
(
batch_size * num_heads,
chunks_count + 1,
window_overlap,
window_overlap * 2 + 1,
)
)
diagonal_attention_scores[
:, :3, :, window_overlap:
] = diagonal_chunked_attention_scores[
:, :, :window_overlap, : window_overlap + 1
]
return diagonal_attention_scores
self.common(
fn,
(
torch.randn(1, 1024, 12, 64),
torch.randn(12, 3, 512, 513),
256,
),
)
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_to_uint8_rounding_method(self):
def fn(x):
return x.to(torch.uint8)
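        # Values straddling rounding boundaries; the vectorized cast must match
        # eager's float -> uint8 conversion semantics for each of them.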
        numerical_testsuite = [4.4, 4.5, 4.6, 5.5]
        for numerical_number in numerical_testsuite:
x = torch.ones(17) * numerical_number
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_decomposed_dequant_relu_quant(self):
def fn(x, scale, zero_point, use_dequant, use_quant):
# For quantized_decomposed.dequantize_per_tensor
# Refer to torch/ao/quantization/fx/_decomposed.py
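            # e.g. with scale=0.01 and zero_point=100: uint8 150 -> (150 - 100) * 0.01 = 0.5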
if use_dequant:
x = (x.to(torch.float32) - zero_point) * scale
x = torch.relu(x)
# For quantized_decomposed.quantize_per_tensor
# Refer to torch/ao/quantization/fx/_decomposed.py
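            # e.g. with scale=0.01 and zero_point=100: 0.5 -> clamp(round(0.5 / 0.01) + 100, 0, 255) = 150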
if use_quant:
inv_scale = 1.0 / scale
x = torch.clamp(torch.round(x * inv_scale) + zero_point, 0, 255).to(
torch.uint8
)
return x
use_dequant_list = [False, True]
use_quant_list = [False, True]
for use_dequant, use_quant in itertools.product(
use_dequant_list, use_quant_list
):
x = torch.clamp(
torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255
)
if use_dequant:
x = x.to(torch.uint8)
zero_point = 100
scale = 0.01
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x, scale, zero_point, use_dequant, use_quant))
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_dequant_quant_lowering(self):
def fn(x, scale, zero_point, use_dequant, use_quant):
if use_dequant:
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, scale, zero_point, 0, 255, torch.uint8
)
x = torch.relu(x)
if use_quant:
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, scale, zero_point, 0, 255, torch.uint8
)
return x
use_dequant_list = [False, True]
use_quant_list = [False, True]
use_tensor_overload_list = [False, True]
for use_dequant, use_quant, use_tensor_overload in itertools.product(
use_dequant_list, use_quant_list, use_tensor_overload_list
):
x = torch.clamp(
torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255
)
if use_dequant:
x = x.to(torch.uint8)
zero_point = 100
scale = 0.01
if use_tensor_overload:
zero_point = torch.tensor(zero_point, dtype=torch.int64)
scale = torch.tensor(scale)
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x, scale, zero_point, use_dequant, use_quant))
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_dequant_maxpool2d_lowering(self):
def fn(x, scale, zero_point):
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, scale, zero_point, 0, 255, torch.uint8
)
max_pool2d_with_indices_default = (
torch.ops.aten.max_pool2d_with_indices.default(
x, [2, 2], [2, 2], [1, 1]
)[0]
)
return max_pool2d_with_indices_default
use_tensor_overload_list = [False, True]
for use_tensor_overload in use_tensor_overload_list:
x = (
torch.clamp(
torch.randn((3, 16, 8, 8), dtype=torch.float32) * 100, 0, 255
)
.to(torch.uint8)
.contiguous(memory_format=torch.channels_last)
)
zero_point = 100
scale = 0.01
if use_tensor_overload:
zero_point = torch.tensor(zero_point, dtype=torch.int64)
scale = torch.tensor(scale)
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x, scale, zero_point))
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_tile2d_load_decomposed_dequant_add_relu_quant(self):
def fn(
x,
scale,
zero_point,
x2,
scale2,
zero_point2,
output_scale,
output_zero_point,
use_dequant,
use_dequant2,
use_quant,
):
if use_dequant:
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, scale, zero_point, 0, 255, torch.uint8
)
if use_dequant2:
x2 = torch.ops.quantized_decomposed.dequantize_per_tensor(
x2, scale2, zero_point2, 0, 255, torch.uint8
)
temp = x + x2
y = torch.relu(temp)
if use_quant:
y = torch.ops.quantized_decomposed.quantize_per_tensor(
y, output_scale, output_zero_point, 0, 255, torch.uint8
)
return y.contiguous()
use_dequant_list = [False, True]
use_dequant_list2 = [False, True]
use_quant_list = [False, True]
for use_dequant, use_dequant2, use_quant in itertools.product(
use_dequant_list, use_dequant_list2, use_quant_list
):
x = torch.clamp(
torch.randn((1, 1024, 14, 14), dtype=torch.float32) * 100, 0, 255
).contiguous(memory_format=torch.channels_last)
x2 = torch.clamp(
torch.randn((1, 1024, 14, 14), dtype=torch.float32) * 100, 0, 255
).contiguous(memory_format=torch.channels_last)
if use_dequant:
x = x.to(torch.uint8).contiguous(memory_format=torch.channels_last)
if use_dequant2:
x2 = x2.to(torch.uint8).contiguous(memory_format=torch.channels_last)
zero_point = 1
scale = 0.01
zero_point2 = 2
scale2 = 0.02
output_zero_point = 3
output_scale = 0.03
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(
fn,
(
x,
scale,
zero_point,
x2,
scale2,
zero_point2,
output_scale,
output_zero_point,
use_dequant,
use_dequant2,
use_quant,
),
)
assert metrics.generated_cpp_vec_kernel_count == 2
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_non_contiguous_load_buf_quant(self):
def fn(
x1,
x2,
groups,
):
x = torch.cat((x1, x2), dim=1)
batchsize, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, 1.0, 0, 0, 255, torch.uint8
)
x = x.view(batchsize, groups, channels_per_group, height, width)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, 1.0, 0, 0, 255, torch.uint8
)
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, 1.0, 0, 0, 255, torch.uint8
)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(batchsize, num_channels, height, width)
return x
x = torch.randint(0, 8, (1, 116, 28, 28), dtype=torch.uint8).contiguous(
memory_format=torch.channels_last
)
x2 = torch.randint(0, 8, (1, 116, 28, 28), dtype=torch.uint8).contiguous(
memory_format=torch.channels_last
)
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(
fn,
(
x,
x2,
2,
),
)
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_tile2d_store_channel_shuffle_cl_quant_output(self):
def channel_shuffle(x, groups, output_scale, output_zero_point):
batchsize, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
x = x.view(batchsize, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(batchsize, -1, height, width)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, output_scale, output_zero_point, 0, 255, torch.uint8
)
return x.contiguous(memory_format=torch.channels_last)
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
x = torch.randn(64, 58, 28, 28)
output_zero_point = 3
output_scale = 0.03
self.common(channel_shuffle, (x, 2, output_scale, output_zero_point))
assert metrics.generated_cpp_vec_kernel_count == 2
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_dequant_relu_quant_dequant_relu_quant_lowering(self):
def fn(x, scale, zero_point, scale2, zero_point2, scale3, zero_point3):
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, scale, zero_point, 0, 255, torch.uint8
)
x = torch.relu(x)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, scale2, zero_point2, 0, 255, torch.uint8
)
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, scale2, zero_point2, 0, 255, torch.uint8
)
x = torch.relu(x)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, scale3, zero_point3, 0, 255, torch.uint8
)
return x
for use_tensor_overload in [True, False]:
x = torch.clamp(
torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, 0, 255
).to(torch.uint8)
zero_point_list = [100, 101, 102]
scale_list = [0.01, 0.02, 0.03]
if use_tensor_overload:
for i in range(len(zero_point_list)):
zero_point_list[i] = torch.tensor(
zero_point_list[i], dtype=torch.int64
)
scale_list[i] = torch.tensor(scale_list[i])
zero_point, zero_point2, zero_point3 = zero_point_list
scale, scale2, scale3 = scale_list
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(
fn,
(x, scale, zero_point, scale2, zero_point2, scale3, zero_point3),
rtol=1e-2,
atol=1e-2,
)
assert metrics.generated_cpp_vec_kernel_count == 1
def test_inplace_add_alpha(self):
def fn(x, y):
aten.add_.Tensor(x, y, alpha=0.55)
return (x,)
x1 = torch.zeros(10)
x2 = torch.zeros(10)
x3 = torch.zeros(10)
y = torch.randn(10)
fn_fx = make_fx(fn)(x1, y)
fn_compiled = compile_fx_inner(fn_fx, [x1, y])
fn(x2, y)
fn_compiled([x3, y])
assert same(x2, x3)
def test_int_div(self):
def fn(x, y):
s3 = x.size(1)
a = torch.zeros((1 + s3) // 2)
a += y
return a, s3
p0 = torch.randint(5, (1, 8))
p1 = torch.randn(1)
self.common(fn, (p0, p1))
def test_no_op_squeeze(self):
@torch._dynamo.optimize("inductor")
def forward(arg0_1):
return torch.ops.aten.squeeze.dim(arg0_1, 1)
x = torch.randn((10, 20))
self.common(forward, (x,))
def test_parallel_num_threads(self):
@torch._dynamo.optimize("inductor")
def fn(x1, x2):
return x1 + x2
@contextlib.contextmanager
def set_num_threads(num_threads):
orig_num_threads = torch.get_num_threads()
torch.set_num_threads(num_threads)
yield
torch.set_num_threads(orig_num_threads)
x1 = torch.randn((10, 20))
x2 = torch.randn((10, 20))
with set_num_threads(1):
assert same(x1 + x2, fn(x1, x2))
with set_num_threads(4):
assert same(x1 + x2, fn(x1, x2))
@patch("torch.cuda.is_available", lambda: False)
def test_timed_cpu_only(self):
timed(lambda: torch.randn(10), ())
def test_complex_memory_overlap(self):
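        # All layouts below are overlap-free views or copies, so the
        # complex-memory-overlap check is expected to be negative for each.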
dense = torch.zeros(64, 32)
self.assertFalse(complex_memory_overlap(dense))
self.assertFalse(complex_memory_overlap(dense.t()))
strided = dense.split(4, dim=1)
self.assertFalse(complex_memory_overlap(strided[0]))
self.assertFalse(complex_memory_overlap(strided[0].t()))
unsqueezed = dense.unsqueeze(1)
self.assertFalse(complex_memory_overlap(unsqueezed))
self.assertFalse(complex_memory_overlap(unsqueezed.permute(1, 2, 0)))
gathered = dense.index_select(0, torch.IntTensor([1, 0, 1]))
self.assertFalse(complex_memory_overlap(gathered))
self.assertFalse(complex_memory_overlap(gathered.t()))
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
def test_vec_dynamic_shapes(self):
def fn(x):
return torch.softmax(x, -1)
value = torch.randn((2, 10))
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (value,))
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_auto_simd(self):
vec_avx512 = codecache.supported_vec_isa_list[0]
vec_avx2 = codecache.supported_vec_isa_list[1]
self.assertTrue(vec_avx512.bit_width() == 512)
self.assertTrue(vec_avx2.bit_width() == 256)
self.assertTrue(vec_avx512.nelements() == 16)
self.assertTrue(vec_avx2.nelements() == 8)
self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32)
self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)
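        # cpp.simdlen behavior exercised below:
        #   None -> auto-pick the widest valid ISA (AVX512 if available, else AVX2)
        #   0, 1, or a value matching no ISA bit width -> vectorization disabled
        #   512 / 256 -> pick AVX512 / AVX2 respectively, when valid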
with config.patch({"cpp.simdlen": None}):
isa = codecache.pick_vec_isa()
if vec_avx512 in codecache.valid_vec_isa_list():
self.assertTrue(isa == vec_avx512)
else:
self.assertTrue(isa == vec_avx2)
with config.patch({"cpp.simdlen": 0}):
isa = codecache.pick_vec_isa()
self.assertFalse(isa)
with config.patch({"cpp.simdlen": 1}):
isa = codecache.pick_vec_isa()
self.assertFalse(isa)
with config.patch({"cpp.simdlen": 257}):
isa = codecache.pick_vec_isa()
self.assertFalse(isa)
with config.patch({"cpp.simdlen": 513}):
isa_list = codecache.valid_vec_isa_list()
if vec_avx512 in isa_list:
self.assertFalse(isa)
with config.patch({"cpp.simdlen": 512}):
isa_list = codecache.valid_vec_isa_list()
if vec_avx512 in isa_list:
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_avx512)
with config.patch({"cpp.simdlen": 256}):
isa_list = codecache.valid_vec_isa_list()
if vec_avx2 in isa_list:
isa = codecache.pick_vec_isa()
self.assertTrue(isa == vec_avx2)
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_masked_fill_softmax(self):
def fn(value, mask):
mask = mask.to(torch.bool)
x = torch.masked_fill(value, mask, -33.0)
return torch.softmax(x, -1)
for dtype in vec_dtypes:
value = torch.randn((2, 17), dtype=dtype)
mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8)
with config.patch({"cpp.simdlen": None}):
for cpp_wrapper_flag in [True, False]:
with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (value, mask))
assert metrics.generated_cpp_vec_kernel_count >= 1
def test_load_same_bool_tensor_twice(self):
@torch._dynamo.optimize("inductor")
def fn(a, b):
x = torch.masked_fill(a, b, -33.0)
y = torch.masked_fill(a, b, -33.0)
return x, y
value = torch.randn((2, 17))
mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8).to(torch.bool)
fn(value, mask)
def test_cpu_vec_cosim(self):
cpp_vec_op_list = []
cpp_op_list = []
for k, v in CppVecOverrides.__dict__.items():
if isinstance(v, staticmethod):
cpp_vec_op_list.append(k)
for k, v in CppOverrides.__dict__.items():
if isinstance(v, staticmethod):
cpp_op_list.append(k)
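        # Ops implemented in the scalar kernel but intentionally absent from the
        # vectorized kernel; every other scalar op must have a vec counterpart.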
diff = [
"index_expr",
"signbit",
"isinf",
"mod",
"masked",
"randn",
"isnan",
"rand",
"randint64",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"bitwise_and",
"bitwise_left_shift",
"bitwise_not",
"bitwise_right_shift",
"bitwise_or",
"bitwise_xor",
"to_dtype_bitcast",
]
union = {*cpp_vec_op_list, *diff}
self.assertTrue(
set(cpp_op_list).issubset(union), f"unexpected: {set(cpp_op_list) - union}"
)
def test_atomic_add_lowp_fp(self):
def fn(test_args):
res = torch.gather(**test_args)
return res
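        # The backward of torch.gather is a scatter-add; running it in
        # low-precision exercises the atomic-add path this test is named for.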
for dtype in _lowp_fp_dtypes:
input_tensor_for_ref = torch.tensor(
[[3.0, -5.0]], dtype=dtype, requires_grad=True
)
input_tensor_for_opt = torch.tensor(
[[3.0, -5.0]], dtype=dtype, requires_grad=True
)
test_args_for_ref = {
"input": input_tensor_for_ref,
"dim": 1,
"index": torch.tensor([[1]]),
}
test_args_for_opt = {
"input": input_tensor_for_opt,
"dim": 1,
"index": torch.tensor([[1]]),
}
opt_fn = torch.compile(fn)
ref_fwd = fn(test_args_for_ref)
res_fwd = opt_fn(test_args_for_opt)
self.assertEqual(res_fwd, ref_fwd)
torch.manual_seed(1)
bwd_tensor_for_ref = torch.randn(ref_fwd.shape, dtype=dtype)
torch.manual_seed(1)
bwd_tensor_for_opt = torch.randn(res_fwd.shape, dtype=dtype)
self.assertEqual(bwd_tensor_for_ref, bwd_tensor_for_opt)
ref_fwd.backward(bwd_tensor_for_ref)
res_fwd.backward(bwd_tensor_for_opt)
ref_grad = test_args_for_ref["input"].grad
res_grad = test_args_for_opt["input"].grad
self.assertEqual(ref_grad, res_grad)
@patch("torch.cuda.is_available", lambda: False)
def test_scatter_using_atomic_add(self):
def fn(a, dim, index, b):
return aten.scatter(a, dim, index, b, reduce="add")
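        # With cpp.fallback_scatter_reduce_sum disabled, inductor should emit an
        # atomic_add-based scatter kernel instead of falling back to aten.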
inps = (
torch.randn(5, 29, 13),
2,
torch.tensor([[[3, 5, 7, 9]]]),
torch.randn(1, 1, 10),
)
fn_opt = torch.compile()(fn)
with config.patch({"cpp.fallback_scatter_reduce_sum": False}):
code = run_and_get_cpp_code(fn_opt, *inps)
FileCheck().check("atomic_add").run(code)
self.assertEqual(
fn(*inps),
fn_opt(*inps),
)
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_new_vec_op_cpu_only(self):
def fn(x):
return torch.log1p(torch.expm1(torch.erf(x)))
for dtype in vec_dtypes:
torch.manual_seed(0)
x = torch.randn((2, 9), dtype=dtype)
x[0, 0] = torch.nan
x[1, -1] = torch.nan
tol = 1e-2 if dtype == torch.bfloat16 else 1e-4
with config.patch({"cpp.simdlen": None}):
for cpp_wrapper_flag in [True, False]:
with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_vec_cpu_only_for_all_available_isa(self):
def fn(x):
return torch.sin(torch.cos(torch.erf(x)))
x = torch.randn((2, 9))
x[0, 0] = torch.nan
x[1, -1] = torch.nan
bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] + [None]
for item in bit_widths:
with config.patch({"cpp.simdlen": item}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
@slowTest
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test__adaptive_avg_pool2d(self):
def wrap_fn(oh, ow):
def fn(x):
return torch._adaptive_avg_pool2d(x, (oh, ow))
return fn
bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()]
ih = [16, 65]
iw = ih
oh = ih
ow = ih
for _ih, _iw, _oh, _ow, _simd_len, dtype in itertools.product(
ih, iw, oh, ow, bit_widths, vec_dtypes
):
x = torch.randn(2, 3, _ih, _iw, dtype=dtype).to(
memory_format=torch.channels_last
)
_fn = wrap_fn(_oh, _ow)
with config.patch({"cpp.simdlen": _simd_len}):
torch._dynamo.reset()
metrics.reset()
self.common(_fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_vec_logical(self):
def wrap_fn1(op: Callable):
def fn(x: torch.Tensor):
return torch.where(op(x), 1.0, 0.0)
return fn
def wrap_fn2(op: Callable):
def fn(x: torch.Tensor, y: torch.Tensor):
return torch.where(op(x, y), 1.0, 0.0)
return fn
for dtype in vec_dtypes:
x = torch.randn(64, dtype=dtype)
y = torch.randn(64, dtype=dtype)
logical_fns = [
torch.logical_and,
torch.logical_not,
torch.logical_or,
torch.logical_xor,
]
for logical_fn in logical_fns:
torch._dynamo.reset()
metrics.reset()
if logical_fn == torch.logical_not:
_fn = wrap_fn1(logical_fn)
_args = (x,)
else:
_fn = wrap_fn2(logical_fn)
_args = (x, y)
self.common(_fn, _args)
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_vec_compare_op_cpu_only(self):
def fn(x):
y1 = torch.eq(x, 1.0)
x = torch.where(y1, x, -x)
y2 = torch.ne(x, 0.0)
x = torch.where(y2, x, -x)
y3 = torch.lt(x, 5.0)
x = torch.where(y3, x, x - 1.0)
y4 = torch.gt(x, -2.0)
x = torch.where(y4, x, x + 1.0)
y5 = torch.le(x, 8.0)
x = torch.where(y5, x, x - 1.0)
y6 = torch.ge(x, -3.0)
x = torch.where(y6, x, x + 1.0)
y7 = x == 1.0
x = torch.where(y7, x, -x)
y8 = x != 0.0
x = torch.where(y8, x, -x)
y9 = x < 5.0
x = torch.where(y9, x, x - 1.0)
y10 = x > -2.0
x = torch.where(y10, x, x + 1.0)
y11 = x <= 8.0
x = torch.where(y11, x, x - 1.0)
y12 = x >= -3.0
x = torch.where(y12, x, x + 1.0)
return x
for dtype in vec_dtypes:
x = torch.randn((2, 9), dtype=dtype)
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
assert (
metrics.generated_kernel_count
- metrics.generated_cpp_vec_kernel_count
) == 0
def test_skip_cpp_codegen(self):
with config.patch({"disable_cpp_codegen": True}):
inps = torch.ones([20]), torch.rand([20])
def f(x, y):
return x + y + torch.tensor(1)
f_opt = torch.compile()(f)
code = run_and_get_cpp_code(f_opt, inps[0], inps[1])
FileCheck().check_not("void kernel").run(code)
self.assertEqual(
f(*inps),
f_opt(*inps),
)
# constant needs to be propagated on fallback
def f(x):
return x[torch.tensor(1) :] * 2
f_opt = torch.compile()(f)
code = run_and_get_cpp_code(f_opt, inps[0])
FileCheck().check_not("void kernel").run(code)
self.assertEqual(f_opt(inps[0]), f(inps[0]))
class Model(torch.nn.Module):
def __init__(
self,
):
super().__init__()
def forward(self, v1: torch.Tensor):
vx = v1.min(dim=1).values
v2 = torch.randn_like(vx)
return v2
model = Model()
x = torch.rand(10, 3, 0)
model_f = torch.compile()(model)
self.assertEqual(model(x), model_f(x))
def test_redundant_to_node_elimination_lowp_fp(self):
def fn(x, y):
res = x + y
res = torch.mean(res)
return res
for dtype in _lowp_fp_dtypes:
x = torch.randn((2, 9), dtype=dtype)
y = torch.randn((2, 9), dtype=dtype)
for torch_compile_debug in [True, False]:
with config.patch(
{"trace.enabled": torch_compile_debug, "cpp.simdlen": None}
):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x, y))
if codecache.valid_vec_isa_list():
assert metrics.generated_cpp_vec_kernel_count == 1
def test_do_not_insert_to_dtype_for_memory_copy_only_kernel(self):
def fn(x):
res = x.clone()
return res
x = torch.randn((100, 100), dtype=torch.bfloat16)
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
assert metrics.cpp_to_dtype_count == 0
if codecache.valid_vec_isa_list():
assert metrics.generated_cpp_vec_kernel_count == 1
def test_insert_to_dtype_count(self):
def fn(x):
res = x.relu()
return res
x = torch.randn((100, 100), dtype=torch.bfloat16)
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
assert metrics.cpp_to_dtype_count == 2
if codecache.valid_vec_isa_list():
assert metrics.generated_cpp_vec_kernel_count == 1
def test_memory_copy_with_fusion(self):
def fn(x):
res = x.relu()
x.copy_(res)
return (res,)
x = torch.randn((100, 100), dtype=torch.bfloat16)
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
assert metrics.cpp_to_dtype_count == 2
if codecache.valid_vec_isa_list():
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_cpp_vec_constant_checker(self):
_graph: torch.fx.Graph = torch.fx.Graph()
a: torch.fx.Node = _graph.create_node("placeholder", "ops")
iv: torch.fx.Node = _graph.create_node("placeholder", "iv")
fv: torch.fx.Node = _graph.create_node("placeholder", "fv")
b: torch.fx.Node = _graph.create_node(
"call_method",
"constant",
args=(
a,
iv,
torch.int64,
),
)
c: torch.fx.Node = _graph.create_node(
"call_method",
"constant",
args=(
a,
fv,
torch.double,
),
)
d: torch.fx.Node = _graph.create_node(
"call_method",
"ge",
args=(
a,
b,
b,
),
)
_graph.output((d, c))
def get_index():
return ""
submodules = {"get_index": get_index}
graph_lowering = GraphLowering(
torch.fx.GraphModule(submodules, _graph),
shape_env=None,
num_static_inputs=0,
)
def set_opt_dtype(graph):
for node in graph.nodes:
if node.target == "constant":
if OptimizationContext.key in node.meta:
opt_ctx = node.meta[OptimizationContext.key]
else:
opt_ctx = OptimizationContext()
opt_ctx.dtype = node.args[-1]
node.meta[OptimizationContext.key] = opt_ctx
with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler(
graph_lowering
):
            # The innermost loop variable is used in the index_expr
tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
with CppVecKernelChecker(
args=None, num_threads=1, tiling_factor=tiling_factor
) as vec_checker:
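                # Constants representable within the int32/float32 ranges keep
                # simd_vec True; the out-of-range runs below must clear it.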
i32_iinfo = np.iinfo(np.int32)
f32_iinfo = np.finfo(np.float32)
set_opt_dtype(_graph)
InterpreterShim(_graph, submodules).run(
V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max
)
self.assertTrue(vec_checker.simd_vec)
vec_checker.simd_vec = True
set_opt_dtype(_graph)
InterpreterShim(_graph, submodules).run(
V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min
)
self.assertTrue(vec_checker.simd_vec)
vec_checker.simd_vec = True
set_opt_dtype(_graph)
InterpreterShim(_graph, submodules).run(
V.get_ops_handler(), i32_iinfo.min, np.inf
)
self.assertTrue(vec_checker.simd_vec)
vec_checker.simd_vec = True
set_opt_dtype(_graph)
InterpreterShim(_graph, submodules).run(
V.get_ops_handler(), i32_iinfo.min, -np.inf
)
self.assertTrue(vec_checker.simd_vec)
vec_checker.simd_vec = True
set_opt_dtype(_graph)
InterpreterShim(_graph, submodules).run(
V.get_ops_handler(), i32_iinfo.min - 1, f32_iinfo.min
)
self.assertFalse(vec_checker.simd_vec)
vec_checker.simd_vec = True
set_opt_dtype(_graph)
InterpreterShim(_graph, submodules).run(
V.get_ops_handler(), i32_iinfo.max + 1, f32_iinfo.max
)
self.assertFalse(vec_checker.simd_vec)
vec_checker.simd_vec = True
set_opt_dtype(_graph)
InterpreterShim(_graph, submodules).run(
V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min * (1 + 1e-5)
)
self.assertFalse(vec_checker.simd_vec)
vec_checker.simd_vec = True
set_opt_dtype(_graph)
InterpreterShim(_graph, submodules).run(
V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max * (1 + 1e-5)
)
self.assertFalse(vec_checker.simd_vec)
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_cpp_vec_index_expr_checker(self):
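        # Build a tiny FX graph of the form ge(index_expr(get_index(), int64), ...)
        # and probe CppVecKernelChecker under different itervars and ranges.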
_graph: torch.fx.Graph = torch.fx.Graph()
a: torch.fx.Node = _graph.create_node("placeholder", "ops")
b: torch.fx.Node = _graph.create_node("call_module", "get_index", args=())
c: torch.fx.Node = _graph.create_node(
"call_method",
"index_expr",
args=(
a,
b,
torch.int64,
),
)
d: torch.fx.Node = _graph.create_node(
"call_method",
"ge",
args=(
a,
c,
c,
),
)
_graph.output(d)
def get_index():
return ""
submodules = {"get_index": get_index}
graph_lowering = GraphLowering(
torch.fx.GraphModule(submodules, _graph),
shape_env=None,
num_static_inputs=0,
)
with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler(
graph_lowering
):
itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")]
tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
            # The innermost loop variable is used in the index_expr
with CppVecKernelChecker(
args=None, num_threads=1, tiling_factor=tiling_factor
) as vec_checker:
def get_index():
return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1]
ranges = [0, 100, 200]
vec_checker.itervars = itervars[:2]
vec_checker.ranges = ranges[:2]
submodules = {"get_index": get_index}
InterpreterShim(_graph, submodules).run(V.get_ops_handler())
self.assertFalse(vec_checker.simd_vec)
            # The innermost loop variable is irrelevant to the index_expr
with CppVecKernelChecker(
args=None, num_threads=1, tiling_factor=tiling_factor
) as vec_checker:
def get_index():
return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1]
ranges = [0, 100, 200]
vec_checker.itervars = itervars
vec_checker.ranges = ranges
submodules = {"get_index": get_index}
InterpreterShim(_graph, submodules).run(V.get_ops_handler())
self.assertTrue(vec_checker.simd_vec)
i32_iinfo = np.iinfo(np.int32)
_max_value = i32_iinfo.max + 1
ranges = [_max_value, _max_value, _max_value]
            # The innermost loop variable is irrelevant, but the max index value
            # exceeds the max value of INT32
with CppVecKernelChecker(
args=None, num_threads=1, tiling_factor=tiling_factor
) as vec_checker:
def get_index():
return itervars[0]
submodules = {"get_index": get_index}
vec_checker.itervars = itervars
vec_checker.ranges = ranges
InterpreterShim(_graph, submodules).run(V.get_ops_handler())
self.assertFalse(vec_checker.simd_vec)
            # The innermost loop variable is irrelevant, but the min index value
            # is below the min value of INT32
with CppVecKernelChecker(
args=None, num_threads=1, tiling_factor=tiling_factor
) as vec_checker:
def get_index():
return -itervars[0] - 2
submodules = {"get_index": get_index}
vec_checker.itervars = itervars
vec_checker.ranges = ranges
InterpreterShim(_graph, submodules).run(V.get_ops_handler())
self.assertFalse(vec_checker.simd_vec)
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_maxpool2d_cpu_only(self):
for dtype in vec_dtypes:
input = torch.randn(26, 32, 112, 112, dtype=dtype).to(
memory_format=torch.channels_last
)
maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def func(x):
return maxpool(x)
with patch.object(config.cpp, "simdlen", None):
torch._dynamo.reset()
metrics.reset()
self.common(func, (input,))
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_maxpool2d_with_pre_loop_collapse_cpu_only(self):
x1 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last)
x2 = torch.randn(2, 3, 20, 20).to(memory_format=torch.channels_last)
maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
def func(x1, x2):
y = x1 + x2
return maxpool(y)
with patch.object(config.cpp, "simdlen", None):
torch._dynamo.reset()
metrics.reset()
self.common(func, (x1, x2))
assert metrics.generated_cpp_vec_kernel_count == 2
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_sign_cpu_only(self):
def fn(x):
return torch.sign(x)
for dtype in vec_dtypes:
x = torch.randn((2, 9), dtype=dtype)
x[0, 0] = torch.nan
x[1, -1] = torch.nan
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_reduction_cpu_only(self):
def fn(x):
return torch.argmax(x, -1)
for dtype in vec_dtypes:
x = torch.randn((10, 10), dtype=dtype)
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 0
    # Currently, we enable AVX2 and AVX512 for vectorization. If the platform is
    # not supported, vectorization will not work and this test case is skipped.
    # To support ARM or other platforms, we just need to add the ISA info to
    # supported_vector_isa and include the proper ATen vectorization header file.
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
@patch("torch.cuda.is_available", lambda: False)
def test_vec_kernel_cpu_only(self):
def fn(x1, x2):
            # Currently, there are some limitations as follows.
            # rsqrt:
            #   assert [both a fallback and a decomp for same kernel: aten.rsqrt.default]
            # round:
            #   couldn't find symbolic meta function/decomposition
            # fmod/logical_and/logical_or:
            #   the vec kernel does not support to_type
x = torch.abs(x1)
x = torch.sin(x)
x = torch.neg(x)
x = torch.square(x)
x = torch.sigmoid(x)
x = torch.relu(x)
x = torch.cos(x)
x = torch.exp(x)
x = torch.sqrt(x)
x = torch.add(x, x1)
x = torch.sub(x, x2)
x = torch.mul(x, x1)
x = torch.div(x, x1)
x = torch.pow(x, 10)
x = torch.log(x)
x = torch.floor(x)
x = torch.ceil(x)
x = torch.trunc(x)
x = torch.lgamma(x)
x = torch.fmod(x, x2)
x = torch.sign(x)
res = x + x2
return res
for dtype in vec_dtypes:
torch.manual_seed(0)
x1 = torch.randn((5, 20), dtype=dtype)
x2 = torch.randn((5, 20), dtype=dtype)
tol = 1e-2 if dtype == torch.bfloat16 else 1e-4
with config.patch({"cpp.simdlen": 1}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x1, x2))
assert metrics.generated_cpp_vec_kernel_count == 0
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
self.common(fn, (x1, x2))
assert metrics.generated_cpp_vec_kernel_count == 1
with config.patch({"cpp.simdlen": None}):
torch._dynamo.reset()
metrics.reset()
x1 = torch.randn(10, 20).permute(1, 0)
x2 = torch.randn((20, 10))
self.common(fn, (x1, x2))
assert metrics.generated_cpp_vec_kernel_count == 2
torch._dynamo.reset()
metrics.reset()
x1 = torch.randn((10, 7))
x2 = torch.randn((10, 7))
self.common(fn, (x1, x2))
assert metrics.generated_cpp_vec_kernel_count == 1
@unittest.skipIf(
sys.platform != "linux", "cpp kernel profile only support linux now"
)
@patch("torch.cuda.is_available", lambda: False)
@config.patch({"cpp.enable_kernel_profile": True})
@config.patch({"cpp.descriptive_names": "original_aten"})
def test_cpp_kernel_profile(self):
from torch.profiler import profile
@torch._dynamo.optimize("inductor", nopython=True)
def fn(a, b):
return a + b
a = torch.rand((100,))
b = torch.rand((100,))
with profile() as prof:
fn(a, b)
kernel_profile_events = []
for e in prof.profiler.function_events:
if "cpp_fused_add_0" in e.name:
kernel_profile_events.append(e.name)
assert len(kernel_profile_events) > 0
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
def test_channel_shuffle_cl_output(self):
"""code and shape extracted from shufflenet_v2_x1_0"""
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
x = x.view(batchsize, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(batchsize, -1, height, width)
return x.contiguous(memory_format=torch.channels_last)
for simdlen in (None, 256, 1):
with config.patch({"cpp.simdlen": simdlen}):
torch._dynamo.reset()
metrics.reset()
x = torch.randn(64, 58, 28, 28)
self.common(channel_shuffle, (x, 2))
if simdlen != 1:
assert metrics.generated_cpp_vec_kernel_count == 2
@slowTest
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
def test_transpose_with_norm(self):
"""a sub-module from TIMM gmlp_s16_224"""
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(
in_features=256, out_features=1536, bias=True
)
self.act = torch.nn.GELU()
self.norm = torch.nn.LayerNorm(768)
self.proj = torch.nn.Linear(196, 196)
self.fc = torch.nn.Linear(in_features=768, out_features=256, bias=True)
def forward(self, x):
x = self.linear(x)
x = self.act(x)
u, v = x.chunk(2, dim=-1)
v = self.norm(v)
v = self.proj(v.transpose(-1, -2))
y = u * v.transpose(-1, -2)
return self.fc(y)
x = torch.randn(128, 196, 256)
for simdlen in (None, 256, 1):
with config.patch({"cpp.simdlen": simdlen}):
for eval_mode in [True, False]:
torch._dynamo.reset()
metrics.reset()
m = Model().eval() if eval_mode else Model()
self.common(m, (x,))
if simdlen != 1:
assert metrics.generated_cpp_vec_kernel_count == 6
@unittest.skipIf(
not codecache.valid_vec_isa_list(), "Does not support vectorization"
)
def test_transpose_copy(self):
def fn(a):
return a.t().contiguous()
for simdlen in (None, 256, 1):
with config.patch({"cpp.simdlen": simdlen}):
for dtype in (torch.float, torch.bfloat16):
for shape in (
(7, 7),
(8, 8),
(9, 9),
(16, 16),
(17, 17),
(32, 32),
(33, 33),
):
torch._dynamo.reset()
metrics.reset()
x = torch.randn(shape, dtype=dtype)
self.common(fn, (x,))
if simdlen != 1:
assert metrics.generated_cpp_vec_kernel_count == 2
def test_horizontal_fusion(self):
def fn(a, b, c, idx):
_a = torch.index_select(a, dim=0, index=idx)
_b = torch.index_select(b, dim=0, index=idx)
_c = torch.index_select(c, dim=0, index=idx)
return _a, _b, _c
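        # cpp.max_horizontal_fusion_size caps how many independent kernels may be
        # fused side by side: 0 or 1 -> no fusion (3 kernels), 2 -> two kernels,
        # 3 -> all three index_select ops fuse into one kernel.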
with config.patch({"cpp.max_horizontal_fusion_size": 0}):
metrics.reset()
torch._dynamo.reset()
a = torch.randn(size=(4, 16), dtype=torch.bfloat16)
b = torch.randn(size=(4, 16), dtype=torch.bfloat16)
c = torch.randn(size=(4, 16), dtype=torch.bfloat16)
idx = torch.zeros(size=[4], dtype=torch.int64)
opt_fn = torch._dynamo.optimize("inductor")(fn)
opt_fn(a, b, c, idx)
self.assertEqual(metrics.generated_kernel_count, 3)
self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))
with config.patch({"cpp.max_horizontal_fusion_size": 1}):
metrics.reset()
torch._dynamo.reset()
a = torch.randn(size=(4, 32), dtype=torch.bfloat16)
b = torch.randn(size=(4, 32), dtype=torch.bfloat16)
c = torch.randn(size=(4, 32), dtype=torch.bfloat16)
idx = torch.zeros(size=[4], dtype=torch.int64)
opt_fn = torch._dynamo.optimize("inductor")(fn)
opt_fn(a, b, c, idx)
self.assertEqual(metrics.generated_kernel_count, 3)
self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))
with config.patch({"cpp.max_horizontal_fusion_size": 2}):
metrics.reset()
torch._dynamo.reset()
a = torch.randn(size=(4, 64), dtype=torch.bfloat16)
b = torch.randn(size=(4, 64), dtype=torch.bfloat16)
c = torch.randn(size=(4, 64), dtype=torch.bfloat16)
idx = torch.zeros(size=[4], dtype=torch.int64)
opt_fn = torch._dynamo.optimize("inductor")(fn)
opt_fn(a, b, c, idx)
self.assertEqual(metrics.generated_kernel_count, 2)
self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))
with config.patch({"cpp.max_horizontal_fusion_size": 3}):
metrics.reset()
torch._dynamo.reset()
a = torch.randn(size=(4, 128), dtype=torch.bfloat16)
b = torch.randn(size=(4, 128), dtype=torch.bfloat16)
c = torch.randn(size=(4, 128), dtype=torch.bfloat16)
idx = torch.zeros(size=[4], dtype=torch.int64)
opt_fn = torch._dynamo.optimize("inductor")(fn)
opt_fn(a, b, c, idx)
self.assertEqual(metrics.generated_kernel_count, 1)
self.assertTrue(same(fn(a, b, c, idx), opt_fn(a, b, c, idx)))
def test_lowp_fp_neg_abs(self):
def fn(x):
return x.neg().abs()
for dtype in _lowp_fp_dtypes:
metrics.reset()
x = torch.randn(100, 100).to(dtype)
opt_fn = torch._dynamo.optimize("inductor")(fn)
self.assertTrue(same(fn(x), opt_fn(x)))
assert metrics.cpp_to_dtype_count == 0
assert metrics.generated_cpp_vec_kernel_count == 1
def test_transpose_non_contiguous(self):
def fn(a):
# From part of timm HaloAttn:
# (https://github.com/rwightman/pytorch-image-models/blob/main/timm/layers/halo_attn.py#L97).
            # Fixes the accuracy issue reported in https://github.com/pytorch/pytorch/issues/94269.
as_strided = torch.ops.aten.as_strided.default(
a, [1, 384, 2, 20, 12], [153600, 1, 61440, 384, 7680]
)
as_strided_1 = torch.ops.aten.as_strided.default(
as_strided,
[1, 384, 2, 2, 12, 12],
[153600, 1, 61440, 3072, 7680, 384],
)
clone_1 = torch.ops.aten.clone.default(
as_strided_1, memory_format=torch.contiguous_format
)
_unsafe_view_1 = torch.ops.aten._unsafe_view.default(
clone_1, [8, 48, 4, 144]
)
permute_2 = torch.ops.aten.permute.default(_unsafe_view_1, [0, 2, 3, 1])
split_with_sizes = torch.ops.aten.split_with_sizes.default(
permute_2, [16, 32], -1
)
getitem = split_with_sizes[0]
getitem_1 = split_with_sizes[1]
permute_3 = torch.ops.aten.permute.default(getitem, [0, 1, 3, 2])
expand_1 = torch.ops.aten.expand.default(permute_3, [8, 4, 16, 144])
clone_3 = torch.ops.aten.clone.default(
expand_1, memory_format=torch.contiguous_format
)
return clone_3
metrics.reset()
x = torch.randn(1, 384, 20, 20).to(memory_format=torch.channels_last)
self.common(fn, (x,))
assert metrics.generated_cpp_vec_kernel_count == 1

    def test_non_contiguous_index_with_constant_stride(self):
        def fn(x):
            x1 = x[:, :, :, ::2]
            x2 = x[:, :, :, 1::2]
            x = torch.stack((-x2, x1), dim=-1)
            return x.flatten(-2)
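
        # NOTE (editorial comment): the ::2 and 1::2 slices load with a
        # constant stride of 2, which the CPP backend can still vectorize;
        # the test expects two vectorized kernels to be generated.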
        metrics.reset()
        x = torch.randn(1, 32, 16, 68)
        opt_fn = torch._dynamo.optimize("inductor")(fn)
        self.assertTrue(same(fn(x), opt_fn(x)))
        assert metrics.generated_cpp_vec_kernel_count == 2

    def test_invalid_index_of_empty_tensor(self):
        def fn(a):
            b = a[[0]]
            return b

        a = torch.tensor([])
        with self.assertRaises(RuntimeError):
            torch.compile(fn)(a)

    def test_ir_node_str(self):
        @torch.compile
        def fn(x: torch.Tensor) -> torch.Tensor:
            return x.sin(), torch.nn.Softmax(dim=1)(x.cos())
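
        # NOTE (editorial comment): wrap GraphLowering.run_node so that str()
        # is exercised on every lowered IR node; the test only checks that the
        # string representations are produced without raising.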
        def run_node_alt(*args, **kwargs):
            rv = run_node(*args, **kwargs)
            strings.append(str(rv))
            return rv

        strings = []
        run_node = GraphLowering.run_node
        with patch.object(GraphLowering, "run_node", run_node_alt):
            fn(torch.randn([8, 128]))

        self.assertGreater(len(strings), 3)

    def test_vertical_sum_cpu_only(self):
        def fn1(a):
            return a.sum(dim=0)

        def fn2(a):
            return a.sum(dim=1)

        metrics.reset()
        x = torch.randn(100, 100)
        self.common(fn1, (x,))
        assert metrics.generated_cpp_vec_kernel_count == 1

        metrics.reset()
        x = torch.randn(100, 100, 100)
        self.common(fn2, (x,))
        assert metrics.generated_cpp_vec_kernel_count == 1

    def test_transpose_vertical_sum_cpu_only(self):
        def fn(a, b):
            c = a * b
            return c.sum(dim=1)
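
        # NOTE (editorial comment): one operand is transposed, so the fused
        # mul + reduction mixes contiguous and strided access; the expected
        # count of two vectorized kernels presumably reflects an extra tiled
        # pass for the transposed-layout loads.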
        metrics.reset()
        x = torch.randn(100, 50, 50)
        y = torch.randn(100, 50, 50).transpose(1, 2)
        self.common(fn, (x, y))
        assert metrics.generated_cpp_vec_kernel_count == 2

    def test_transpose_sum2d_cpu_only(self):
        def fn(a, b):
            c = a * b
            return c.sum()

        metrics.reset()
        x = torch.randn(50, 50)
        y = torch.randn(50, 50).transpose(0, 1)
        self.common(fn, (x, y))
        assert metrics.generated_cpp_vec_kernel_count == 2

    def test_transpose_sum_outer(self):
        # https://github.com/pytorch/pytorch/issues/98573
        def fn(a):
            return a.transpose(2, 3).sum(dim=1).contiguous()

        metrics.reset()
        x = torch.randn(10, 50, 50, 50)
        self.common(fn, (x,))
        assert metrics.generated_cpp_vec_kernel_count == 1

    def test_to_dtype_bool_float(self):
        # https://github.com/pytorch/pytorch/issues/100800
        def f(a):
            return torch.where(
                torch.ones_like(a).to(torch.bool),
                torch.zeros_like(a),
                torch.ones_like(a) * 2,
            )

        self.common(f, (torch.ones(16),))

    def test_to_dtype_float_bool(self):
        # https://github.com/pytorch/pytorch/issues/100466
        def f(a):
            a = a * torch.tensor(a >= 0, dtype=torch.float32)
            return a

        x = torch.rand(16)
        self.common(f, (x,))

    def test_constant_store(self):
        # https://github.com/pytorch/pytorch/issues/104515
        def f(a):
            a[0, [3, 3]] = -float("inf")
            return a

        x = torch.rand(4, 5)
        self.common(f, (x,))

    def test_to_channels_last_lowp_fp(self):
        def f(a):
            return a.to(memory_format=torch.channels_last)

        for dtype in _lowp_fp_dtypes:
            x = torch.rand(2, 3, 14, 14).to(dtype)
            self.common(f, (x,))

    def test_broadcast_mul_lowp_fp(self):
        def f(a, b):
            return a * b

        for dtype in _lowp_fp_dtypes:
            a = torch.randn(2, 16, 16).to(dtype)
            b = torch.randn(2, 1, 1).to(dtype)
            self.common(f, (a, b))

    def test_linear_buffer_reuse(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(16, 16)
                self.tanh = torch.nn.Tanh()
                self.linear2 = torch.nn.Linear(16, 16)

            def forward(self, x):
                x = self.linear1(x)
                x = self.tanh(x)
                x = self.linear2(x)
                return x

        mod = M().eval()
        v = torch.randn(1, 16)
        with torch.no_grad():

            def compile_fx_wrapper(model_, example_inputs_):
                return compile_fx(model_, example_inputs_)

            def run(*ex, **kwargs):
                return mod(*ex, **kwargs)
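
            # NOTE (editorial comment): if intermediate buffers are reused in
            # place, the generated wrapper should not need to rebuild views
            # via as_strided, hence the substring check below.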
            run = torch._dynamo.optimize(compile_fx_wrapper)(run)
            code = run_and_get_cpp_code(run, v)
            self.assertNotIn("= as_strided(", code)
            self.assertEqual(run(v), mod(v))

    @config.patch(inplace_buffers=True)
    def test_in_out_buffer(self):
        def fn(x, y):
            z = torch.matmul(x, y.transpose(-1, -2)) / 8.0
            return z
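
        # NOTE (editorial comment): with inplace_buffers enabled, Inductor may
        # pass a single pointer that serves as both input and output; the
        # generated C++ kernel signature then contains an "in_out_ptr"
        # argument, which the check below looks for.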
        inps = [torch.randn(1, 2, 8, 4), torch.randn(1, 2, 8, 4)]
        fn_opt = torch._dynamo.optimize("inductor")(fn)
        code = run_and_get_cpp_code(fn_opt, *inps)
        self.assertIn("in_out_ptr", code)
        self.assertEqual(fn_opt(*inps), fn(*inps))

    def test_eliminate_meaningless_copy(self):
        def fn(x1, x2):
            permute = torch.ops.aten.permute.default(x2, [0, 2, 1, 3])
            clone = torch.ops.aten.clone.default(
                permute, memory_format=torch.contiguous_format
            )
            view = torch.ops.aten.view.default(clone, [1024, -1, 32])
            bmm = torch.ops.aten.bmm.default(view, x1)
            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
            return (bmm, permute)
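
        # NOTE (editorial comment): the permute + contiguous clone chain is a
        # layout-only copy; the test expects Inductor to elide the redundant
        # part so that only a single generated kernel remains.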
        metrics.reset()
        self.common(
            fn,
            [
                rand_strided(
                    (1024, 32, 128), (4096, 1, 32), device="cpu", dtype=torch.float32
                ),
                rand_strided(
                    (64, 128, 16, 32),
                    (65536, 512, 32, 1),
                    device="cpu",
                    dtype=torch.float32,
                ),
            ],
        )
        self.assertEqual(metrics.generated_kernel_count, 1)

    def test_scalar_mul_bfloat16(self):
        def f(x):
            return torch.ops.aten.mul.Tensor(x, 1.7015043497085571)

        metrics.reset()
        x = torch.randn(4, 5, dtype=torch.bfloat16)
        self.common(f, (x,))
        assert metrics.generated_cpp_vec_kernel_count == 1

    def test_bf16_zeros(self):
        def fn():
            x = torch.zeros(1, 1, 32, dtype=torch.bfloat16)
            return x

        self.common(fn, ())

    def test_select_tiling_with_index_expr(self):
        def fn(x, y):
            x = torch.ops.aten.view.default(x, [8, 8, 8, 3136])
            x = torch.ops.aten.permute.default(x, [0, 1, 3, 2])
            y = torch.ops.aten.mul.Tensor(y, x)
            return torch.ops.aten.constant_pad_nd.default(y, [0, 0, 1, 0, 0, 0], 0.0)

        x = torch.randn(8, 64, 56, 56)
        y = torch.randn(8, 8, 3136, 8)
        self.common(fn, (x, y))

    @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled")
    @patch("torch.cuda.is_available", lambda: False)
    @config.patch(freezing=True)
    def test_linear_with_no_default_contiguous_input(self):
        mod = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
        temp = torch.randn(1, 16, 1, 1)
        v = torch.as_strided(temp, [1, 16], [0, 1], 0)
        self.assertTrue(v.is_contiguous())
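
        # NOTE (editorial comment): the as_strided view has strides (0, 1)
        # rather than the default (16, 1) of a (1, 16) tensor; because the
        # first dimension has size 1 it still reports is_contiguous() == True,
        # and the frozen linear must handle this non-default layout.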
        with torch.no_grad():
            self.common(
                mod,
                (v,),
            )

        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            mod = mod.to(torch.bfloat16)
            v = v.to(torch.bfloat16)
            with torch.no_grad():
                self.common(
                    mod,
                    (v,),
                )
@patch("torch.cuda.is_available", lambda: False)
@config.patch(freezing=True)
def test_linear_with_reshape(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(16, 16, bias=False)
def forward(self, x):
x = self.linear(x)
return x.view(4, 4, 4)
mod = M().eval()
v = torch.randn(4, 16)
with torch.no_grad():
torch._dynamo.reset()
metrics.reset()
self.common(
mod,
(v,),
)
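
            # NOTE (editorial comment): with freezing, the linear is expected
            # to be handled entirely by an external (mkldnn) kernel and the
            # trailing view is a pure metadata change, so no Inductor C++
            # kernel should be generated at all.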
            assert metrics.generated_kernel_count == 0
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
from torch.testing._internal.inductor_utils import HAS_CPU
if HAS_CPU and not IS_MACOS:
run_tests(needs="filelock")