# Owner(s): ["module: inductor"]
import copy
import unittest
import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import count_calls, counters
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.fx_passes import joint_graph
from torch._inductor.fx_passes.serialized_patterns.central_index import (
get_serialized_pattern,
)
from torch._inductor.pattern_matcher import (
_TargetExpr,
Arg,
CallFunction,
gen_pattern,
KeywordArg,
Match,
PatternExpr,
PatternMatcherPass,
PatternPrettyPrinter,
register_graph_pattern,
)
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA
class TestPatternMatcher(TestCase):
def common(
self,
fn,
args,
expected_matches,
expected_nodes,
additional_check=lambda code: None,
):
counters.clear()
torch.manual_seed(42)
expected = fn(*args)
torch.manual_seed(42)
actual, codes = run_and_get_code(torch.compile(fn), *args)
if len(codes) == 1:
codes = codes[0]
torch.testing.assert_close(actual, expected)
self.assertEqual(
counters["inductor"]["pattern_matcher_count"], expected_matches
)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], expected_nodes)
additional_check(codes)
counters.clear()
def test_mm_plus_mm(self):
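# The mm + mm pattern should match once per call, covering the two mm nodes and the add (3 nodes).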
def fn(a, b, c, d):
return torch.add(torch.mm(a, b), torch.mm(c, d))
args_list = [
(
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
),
# https://github.com/pytorch/pytorch/issues/100670.
(
torch.randn(1, 4, device="cuda"),
torch.randn(4, 2, device="cuda"),
torch.randn(1, 2, device="cuda"),
torch.randn(2, 1, device="cuda"),
),
(
torch.randn(1, 2, device="cuda"),
torch.randn(2, 1, device="cuda"),
torch.randn(1, 4, device="cuda"),
torch.randn(4, 2, device="cuda"),
),
(
torch.randn(1, 4, device="cuda"),
torch.randn(4, 2, device="cuda"),
torch.randn(1, 5, device="cuda"),
torch.randn(5, 2, device="cuda"),
),
]
for args in args_list:
self.common(fn, args, 1, 3)
def _test_fused_int_mm_mul_impl(self, fn, args, fused_int_mm_mul_expected=True):
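# Compile with max-autotune and check whether the generated code contains the
# fused_int_mm_mul kernel; numerics are compared only where the reference is finite.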
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn, mode="max-autotune"), *args)
self.assertEqual("fused_int_mm_mul" in code, fused_int_mm_mul_expected)
if fused_int_mm_mul_expected:
indices = ~ref.isinf()
torch.testing.assert_close(
ref[indices], test[indices]
) # also checks that dtype is correct
@skipIfRocm
@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(force_fuse_int_mm_with_mul=True)
def test_fused_int_mm_mul(self):
def fn1(a, b, c):
return out_dtype(torch.ops.aten.mm.default, torch.int32, a, b) * c
def fn2(a, b, c):
return (out_dtype(torch.ops.aten.mm.default, torch.int32, a, b) * c).to(
torch.bfloat16
)
args_list = [
(
torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
torch.randn((32, 1), dtype=torch.float16, device="cuda") * 0 + 0.5,
),
(
torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
torch.randn((1, 8), dtype=torch.bfloat16, device="cuda"),
),
(
torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
torch.randn((1, 8), dtype=torch.float32, device="cuda"),
),
]
for args in args_list:
self._test_fused_int_mm_mul_impl(fn1, args, True)
self._test_fused_int_mm_mul_impl(fn2, args, True)
@skipIfRocm
@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(force_fuse_int_mm_with_mul=True)
def test_fused_int_mm_mul_gating(self):
def fn1(a, b, c):
return out_dtype(torch.ops.aten.mm.default, torch.int32, a, b) * c
args1 = (
torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
torch.randn((8), dtype=torch.float32, device="cuda"),
)
args2 = (
torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"),
torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"),
torch.randn((32, 1), dtype=torch.float16, device="cuda"),
)
self._test_fused_int_mm_mul_impl(fn1, args1, False)
self._test_fused_int_mm_mul_impl(fn1, [arg.cpu() for arg in args2], False)
inductor_config.force_fuse_int_mm_with_mul = False
self._test_fused_int_mm_mul_impl(fn1, args2, False)
def _test_mixed_impl(self, fn, args, mixed_mm_expected, fallback_mixed_mm_expected):
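# Checks which mixed-dtype mm shows up in the generated code: the mixed_mm kernel,
# the fallback_mixed_mm kernel, or neither.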
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertEqual("mixed_mm" in code, mixed_mm_expected)
self.assertEqual("fallback_mixed_mm" in code, fallback_mixed_mm_expected)
@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(force_mixed_mm=True)
def test_mixed_mm(self):
def fn(a, b):
return torch.mm(a, b.to(a.dtype))
args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
),
(
torch.randn(8, 2, device="cuda", dtype=torch.bfloat16),
torch.randint(-128, 127, (2, 8), dtype=torch.int8, device="cuda"),
),
(
torch.randn(8, 5, device="cuda", dtype=torch.float16),
torch.randint(0, 255, (5, 2), dtype=torch.uint8, device="cuda"),
),
(
torch.randn(8, 8, device="cuda", dtype=torch.float32),
torch.randn(8, 8, device="cuda", dtype=torch.bfloat16),
),
]
for args in args_list:
self._test_mixed_impl(fn, args, True, False)
@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(force_mixed_mm=True)
def test_mixed_mm_bad_cases(self):
def fn(a, b):
return torch.mm(a, b.to(a.dtype))
args_list = [
(
torch.randn(8, 8, device="cuda", dtype=torch.float16),
torch.randint(-128, 127, (2, 8), dtype=torch.int8, device="cuda").t(),
),
(
torch.randn(8, 8, device="cuda", dtype=torch.bfloat16),
torch.randint(0, 255, (2, 8), dtype=torch.uint8, device="cuda").t(),
),
]
for args in args_list:
self._test_mixed_impl(fn, args, True, True)
@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(force_mixed_mm=True, max_autotune_gemm=True)
def test_mixed_mm_epi_works(self):
def fn(a, b, c, d):
return torch.mm(a, b.to(a.dtype)) * c + d
args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
torch.randn(8, device="cuda"),
torch.randn(8, device="cuda"),
),
(
torch.randn(8, 2, device="cuda", dtype=torch.bfloat16),
torch.randint(-128, 127, (2, 8), dtype=torch.int8, device="cuda"),
torch.randn(8, device="cuda", dtype=torch.bfloat16),
torch.randn(8, device="cuda", dtype=torch.bfloat16),
),
(
torch.randn(8, 5, device="cuda", dtype=torch.float16),
torch.randint(0, 255, (5, 2), dtype=torch.uint8, device="cuda"),
torch.randn(2, device="cuda", dtype=torch.float16),
torch.randn(2, device="cuda", dtype=torch.float16),
),
]
for args in args_list:
self._test_mixed_impl(fn, args, True, False)
@unittest.skipIf(not SM80OrLater, "need sm_80")
def test_mixed_mm_gating(self):
def fn(a, b):
return torch.mm(a, b.to(a.dtype))
args = (
torch.randn(8, 8, device="cuda"),
torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"),
)
# will ignore the mixed_mm code (including fallback)
with inductor_config.patch({"force_mixed_mm": False, "use_mixed_mm": False}):
self._test_mixed_impl(fn, args, False, False)
# will use the fallback_mixed_mm kernel since gemm autotuning is not enabled
with inductor_config.patch({"force_mixed_mm": False, "use_mixed_mm": True}):
self._test_mixed_impl(fn, args, True, True)
# will use mixed_mm kernel
with inductor_config.patch({"force_mixed_mm": True, "use_mixed_mm": False}):
self._test_mixed_impl(fn, args, True, False)
# shows that use_mixed_mm doesn't do anything if force_mixed_mm is set
with inductor_config.patch({"force_mixed_mm": True, "use_mixed_mm": True}):
self._test_mixed_impl(fn, args, True, False)
@inductor_config.patch(use_mixed_mm=True)
def test_mixed_mm_cpu(self):
def fn(a, b):
return torch.mm(a, b.to(a.dtype))
args = (
torch.randn(8, 8),
torch.randint(-128, 127, (8, 8), dtype=torch.int8),
)
self._test_mixed_impl(fn, args, False, False)
@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm(self):
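# b packs two int4 values per uint8; (b & 0xF, b >> 4) unpacks them, and the whole
# expression should be matched into the uint4x2_mixed_mm kernel.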
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)
args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
),
(
torch.randn(8, 8, device="cuda", dtype=torch.float16),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda")
.t()
.contiguous()
.t(),
),
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.int32, device="cuda"),
),
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.int64, device="cuda"),
),
]
for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertTrue("uint4x2_mixed_mm" in code)
@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm_epi(self):
def fn(a, b, c, d):
return (
torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)
* c
+ d
)
args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
torch.randn(8, device="cuda"),
torch.randn(8, device="cuda"),
),
]
for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertTrue("uint4x2_mixed_mm" in code)
self.assertTrue("fused_add_mm_mul" in code)
@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm_fail_to_match(self):
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)
args_list = [
( # cpu
torch.randn(8, 8),
torch.randint(0, 255, (4, 8), dtype=torch.uint8),
),
( # int8
torch.randn(8, 8, device="cuda"),
torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda"),
), # we don't match for int8, since the bitshift numerics
] # differ between triton and pytorch for signed int8
for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertFalse("uint4x2_mixed_mm" in code)
@inductor_config.patch(use_mixed_mm=False)
def test_uint4x2_mixed_mm_gating_works(self):
def fn(a, b):
return torch.mm(
a,
torch.cat((b & 0xF, b >> 4), 1)
.reshape(-1, b.shape[1])
.to(a.dtype)
.sub(8),
)
args_list = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
),
]
for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertFalse("uint4x2_mixed_mm" in code)
def test_addmm(self):
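# Each case is (a, b, c, should_fuse); fusible cases expect 2 matches (one per returned expression) covering 4 nodes.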
def fn(a, b, c):
return torch.add(a, torch.mm(b, c)), torch.mm(b, c) + a
args_list = [
(
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
True,
),
(
torch.randn(8, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 8, device="cuda"),
True,
),
(
torch.randn(16, 16, device="cuda"),
torch.randn(1, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
False,
),
(
torch.randn(1, 16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
False,
),
(
4,
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
False,
),
]
for a, b, c, should_fuse in args_list:
torch._dynamo.reset()
counters.clear()
args = (a, b, c)
e1, e2 = fn(*args)
a1, a2 = torch.compile(fn)(*args)
torch.testing.assert_close(a1, e1)
torch.testing.assert_close(a2, e2)
count, nodes = (2, 4) if should_fuse else (0, 0)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], count)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], nodes)
def test_addmm_symbolic_scalar(self):
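# With dynamic=True the bias is a symbolic scalar (m1.size(0)), so the addmm pattern should not match.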
def fn(m1, m2):
bias = m1.size(0)
return torch.add(bias, torch.mm(m1, m2)), torch.mm(m1, m2) + bias
m1 = torch.randn(16, 16, device="cuda")
m2 = torch.randn(16, 16, device="cuda")
counters.clear()
expect = fn(m1, m2)
actual = torch.compile(fn, dynamic=True)(m1, m2)
self.assertEqual(expect, actual)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
def test_cat_mm(self):
def fn(a, b, c):
return torch.cat(
[
torch.mm(a, b),
torch.mm(b, c),
torch.mm(a, c),
],
1,
)
args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
self.common(fn, args, 2, 5)
def test_cat_addmm(self):
def fn(a, b, c):
return torch.cat(
[
torch.addmm(a, b, c),
torch.addmm(b, c, a),
torch.addmm(c, a, b),
],
1,
)
args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
self.common(fn, args, 2, 5)
def test_cat_slice_cat(self):
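# A cat followed by a slice that is concatenated back with the original cat should
# match once, covering 3 nodes.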
def fn(a, b):
cat_1 = torch.ops.aten.cat.default([a, b], 1)
slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
return torch.ops.aten.cat.default([cat_1, slice_2], 1)
args = [
torch.randn(2, 32, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
self.common(fn, args, 1, 3)
args = [
torch.randn(2, 8, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
counters.clear()
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
# We don't recompile for dynamic-shape cases, so only check the counters when static shapes are assumed.
if dynamo_config.assume_static_by_default:
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)
# Verify we fall back to the non-optimized path for a negative `end`.
def fn(a, b):
cat_1 = torch.ops.aten.cat.default([a, b], 1)
slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, -1)
return torch.ops.aten.cat.default([cat_1, slice_2], 1)
args = [
torch.randn(2, 8, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
self.common(fn, args, 1, 3)
def test_pointless_convert(self):
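# Converting to float16 and then straight to float32 is simplified by the joint-graph
# pass (one convert removed); converting to int32 and then to float32 is not.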
def fn1(x):
x = torch.ops.prims.convert_element_type.default(x, torch.float16)
x = torch.ops.prims.convert_element_type.default(x, torch.float32)
return x
gm = torch.fx.symbolic_trace(fn1)
self.assertEqual(count_calls(gm.graph), 2)
joint_graph.joint_graph_passes(gm)
self.assertEqual(count_calls(gm.graph), 1)
def fn2(x):
x = torch.ops.prims.convert_element_type.default(x, torch.int32)
x = torch.ops.prims.convert_element_type.default(x, torch.float32)
return x
gm = torch.fx.symbolic_trace(fn2)
self.assertEqual(count_calls(gm.graph), 2)
joint_graph.joint_graph_passes(gm)
self.assertEqual(count_calls(gm.graph), 2)
# Constant folding was explicitly turned off due to issue #108388.
# Turn it back on for this test.
@inductor_config.patch(joint_graph_constant_folding=True)
def test_pointless_cumsum(self):
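# cumsum over a tensor filled with a constant should be constant-folded, so no
# aten.cumsum call should appear in the generated code.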
def fn1():
ones = torch.full(
[1, 128], 1, layout=torch.strided, dtype=torch.float32
).to(torch.int64)
return torch.cumsum(ones, 1) * ones
def fn2():
ones = torch.full(
[55, 10], 1, layout=torch.strided, dtype=torch.float32
).to(torch.int64)
return torch.cumsum(ones, 1)
def fn3():
twos = torch.full([5, 4, 3], 2, dtype=torch.int64)
return torch.cumsum(twos, 0)
def fn4():
x = torch.full([100], 0.1, dtype=torch.float32)
return torch.cumsum(x, 0)
def fn5():
t1 = torch.full([2, 4], 1)
t2 = t1.to(dtype=torch.bool)
return torch.cumsum(t2, 1)
def fn6():
x = torch.full([10, 10], True, dtype=torch.int32)
return torch.cumsum(x, 1)
for fn in (fn1, fn2, fn3, fn4, fn5, fn6):
result, (code,) = run_and_get_code(torch.compile(fn, fullgraph=True))
self.assertNotIn("aten.cumsum", code)
self.assertEqual(result, fn())
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
counters.clear()
def test_splitwithsizes_cat(self):
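# split_with_sizes followed by a cat that reassembles all the pieces, in order and
# along the same dim, should be matched; the variants below must not match.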
# Good case
def fn(a):
split_with_sizes = torch.ops.aten.split_with_sizes.default(a, [8, 24], 1)
getitem = split_with_sizes[0]
getitem_1 = split_with_sizes[1]
cat = torch.ops.aten.cat.default([getitem, getitem_1], 1)
return cat**2
args = [
torch.randn(2, 32, device="cuda"),
]
self.common(fn, args, 1, 4)
# Not all getitems are passed to cat
def fn(a):
split_with_sizes = torch.ops.aten.split_with_sizes.default(a, [8, 8, 16], 1)
getitem = split_with_sizes[0]
getitem_1 = split_with_sizes[1]
getitem_2 = split_with_sizes[2]
cat = torch.ops.aten.cat.default([getitem, getitem_1], 1)
return cat**2 + getitem_2
args = [
torch.randn(2, 32, device="cuda"),
]
self.common(fn, args, 0, 0)
# Split and cat along different dims (TODO: this case should be handled by replacing with a reshape)
def fn(a):
split_with_sizes = torch.ops.aten.split_with_sizes.default(
a, [8, 8, 8, 8], 1
)
cat = torch.ops.aten.cat.default(split_with_sizes, 0)
return cat**2
args = [
torch.randn(2, 32, device="cuda"),
]
self.common(fn, args, 0, 0)
# https://github.com/pytorch/pytorch/issues/99686.
def fn(a):
x = torch.ops.aten.split_with_sizes.default(a, [3, 2, 3], dim=1)
cat = torch.ops.aten.cat.default([x[1], x[0], x[2]], dim=1)
return cat
args = [
torch.randn(1, 8, device="cuda"),
]
self.common(fn, args, 0, 0)
def test_cat_splitwithsizes(self):
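# The inverse pattern: a cat whose output is split back with matching sizes along the
# same dim should be matched; the variants below must not match.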
# good case
def fn(a, b, c):
cat = torch.ops.aten.cat.default([a, b, c], 1)
split_with_sizes = torch.ops.aten.split_with_sizes.default(
cat, [2, 3, 5], 1
)
return [s**2 for s in split_with_sizes]
args = [
torch.randn(2, 2, device="cuda"),
torch.randn(2, 3, device="cuda"),
torch.randn(2, 5, device="cuda"),
]
self.common(fn, args, 1, 2)
# cat node has other users
def fn(a, b, c):
cat = torch.ops.aten.cat.default([a, b, c], 1)
split_with_sizes = torch.ops.aten.split_with_sizes.default(
cat, [2, 3, 5], 1
)
return [s**2 for s in split_with_sizes] + [cat**3]
args = [
torch.randn(2, 2, device="cuda"),
torch.randn(2, 3, device="cuda"),
torch.randn(2, 5, device="cuda"),
]
self.common(fn, args, 0, 0)
# cat and split dims are different
def fn(a, b, c):
cat = torch.ops.aten.cat.default([a, b, c], 1)
split_with_sizes = torch.ops.aten.split_with_sizes.default(
cat, [2, 3, 5], 0
)
return [s**2 for s in split_with_sizes]
args = [
torch.randn(10, 2, device="cuda"),
torch.randn(10, 3, device="cuda"),
torch.randn(10, 5, device="cuda"),
]
self.common(fn, args, 0, 0)
# cat and split lengths are different
def fn(a, b, c):
cat = torch.ops.aten.cat.default([a, b, c], 1)
split_with_sizes = torch.ops.aten.split_with_sizes.default(cat, [5, 5], 1)
return [s**2 for s in split_with_sizes]
args = [
torch.randn(2, 2, device="cuda"),
torch.randn(2, 3, device="cuda"),
torch.randn(2, 5, device="cuda"),
]
self.common(fn, args, 0, 0)
# cat input sizes and split sizes are different
def fn(a, b, c):
cat = torch.ops.aten.cat.default([a, b, c], 1)
split_with_sizes = torch.ops.aten.split_with_sizes.default(
cat, [2, 5, 3], 1
)
return [s**2 for s in split_with_sizes]
args = [
torch.randn(2, 2, device="cuda"),
torch.randn(2, 3, device="cuda"),
torch.randn(2, 5, device="cuda"),
]
self.common(fn, args, 0, 0)
def test_match_with_mutation(self):
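# With prevent_match_across_mutations=True, add(x, sin(x)) should match only in fn0;
# the other functions interpose a mutation or other state change between the sin and the add.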
counter = 0
test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
@register_graph_pattern(
CallFunction(
torch.add, KeywordArg("x"), CallFunction(torch.sin, KeywordArg("x"))
),
pass_dict=test_pass,
)
def _test(match, x):
nonlocal counter
counter += 1
def fn0(x, y):
a = torch.sin(x)
b = torch.add(x, a)
return b
def fn1(x, y):
a = torch.sin(x)
x.copy_(y)
b = torch.add(x, a)
return b
def fn2(x, y):
a = torch.sin(x)
with torch.no_grad():
b = torch.add(x, a)
return b
def fn3(x, y):
a = torch.sin(x)
with torch.autocast("cuda"):
b = torch.add(x, a)
return b
def fn4(x, y):
a = torch.sin(x)
torch.manual_seed(1234)
b = torch.add(x, a)
return b
def fn5(x, y):
a = torch.sin(x)
torch.add(y, 1, out=x)
b = torch.add(x, a)
return b
args = [
torch.randn(5, 5, device="cuda"),
torch.randn(5, 5, device="cuda"),
]
with unittest.mock.patch(
"torch._inductor.fx_passes.pre_grad.pattern_matcher_passes", [test_pass]
):
for fn in (fn0, fn1, fn2, fn3, fn4, fn5):
counter = 0
expected = fn(*copy.deepcopy(args))
actual = torch.compile(fn)(*copy.deepcopy(args))
# only fn0 should match
self.assertEqual(counter, int(fn is fn0))
torch.testing.assert_close(actual, expected)
def test_remove_pointless_clones(self):
@torch.compile(fullgraph=True)
def fn(a, b):
return torch.mm(a, b).clone()
result, (code) = run_and_get_code(fn, torch.randn(8, 8), torch.randn(8, 8))
# without this pass, the clone would allocate a second buffer (buf1)
self.assertIn("return (buf0, )", code[0])
self.assertNotIn("async_compile.cpp", code[0])
def test_unfuse_bias_addmm(self):
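# A plain addmm should stay as an extern addmm call, but when a pointwise op (gelu)
# consumes the result, the addmm is expected to be unfused (no extern addmm in the code).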
args = [
torch.randn(20, device="cuda"),
torch.randn(10, 15, device="cuda"),
torch.randn(15, 20, device="cuda"),
]
@torch.compile()
def fn(inp, a, b):
return torch.ops.aten.addmm(inp, a, b)
_, (code) = run_and_get_code(fn, args[0], args[1], args[2])
FileCheck().check("extern_kernels.addmm(").run(code[0])
@torch.compile()
def fn2(inp, a, b):
return torch.nn.functional.gelu(torch.ops.aten.addmm(inp, a, b))
_, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
@torch.compile()
def fn2(inp, a, b):
return torch.nn.functional.gelu(
torch.ops.aten.addmm(inp, a, b).unsqueeze(0)
)
# hit the view path
_, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_fuse_attention_roundtrip_pattern(self):
# check that nothing is lost when a pattern is pretty-printed and re-evaluated
from torch._inductor.fx_passes.fuse_attention import _get_sfdp_patterns
global_vals = {
"aten": torch.ops.aten,
"prims": torch.ops.prims,
"torch": torch,
}
for name in dir(torch._inductor.pattern_matcher):
attr = getattr(torch._inductor.pattern_matcher, name)
if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
global_vals[name] = attr
with torch._subclasses.FakeTensorMode():
for _, kwargs in _get_sfdp_patterns():
gen_kwargs = {
key: kwargs[key]
for key in (
"search_fn",
"example_inputs",
"trace_fn",
"scalar_workaround",
)
}
pattern = gen_pattern(**gen_kwargs)
pattern_pp = PatternPrettyPrinter.run(pattern)
env = global_vals.copy()
exec(pattern_pp, env)
pattern_2 = env["output"]
self.assertEqual(pattern_pp, PatternPrettyPrinter.run(pattern_2))
def test_fuse_attention_all_patterns_serialized(self):
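# Freshly generated SDPA patterns must match the serialized copies produced by
# gen_attention_patterns.py; patterns without a serialized form are skipped.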
from torch._inductor.fx_passes.fuse_attention import _get_sfdp_patterns
with torch._subclasses.FakeTensorMode():
for key, kwargs in _get_sfdp_patterns():
gen_kwargs = {
key: kwargs[key]
for key in (
"search_fn",
"example_inputs",
"trace_fn",
"scalar_workaround",
)
}
pattern = gen_pattern(**gen_kwargs)
pattern_pp = PatternPrettyPrinter.run(pattern)
search_fn_pattern = get_serialized_pattern(key)
if search_fn_pattern is None:
continue
self.assertEqual(
pattern_pp,
PatternPrettyPrinter.run(search_fn_pattern),
msg=f"Found mismatched pattern {key}. Run gen_attention_patterns.py",
)
def test_match_equivalent_function_invocations1(self):
counter = 0
test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
args = [
torch.randn(20, device="cuda"),
torch.randn(10, 15, device="cuda"),
torch.randn(15, 20, device="cuda"),
]
def f0(inp, a, b):
return torch.ops.aten.addmm(inp, a, b)
def f1(inp, a, b):
return torch.ops.aten.addmm(inp, a, b, beta=1.0)
def f2(inp, a, b):
return torch.ops.aten.addmm(inp, a, b, beta=1.0, alpha=1.0)
# This graph pattern should successfully match all of the above functions
@register_graph_pattern(
CallFunction(
torch.ops.aten.addmm,
Arg(),
Arg(),
Arg(),
beta=KeywordArg("beta"),
alpha=KeywordArg("alpha"),
),
pass_dict=test_pass,
)
def addmm_replacement(match: Match, inp, mat1, mat2, beta, alpha):
nonlocal counter
counter += 1
def repl(inp, x1, x2):
return (x1 @ x2) * alpha + inp * beta
with V.fake_mode:
match.replace_by_example(repl, [inp, mat1, mat2])
with unittest.mock.patch(
"torch._inductor.fx_passes.post_grad.pass_patterns",
torch._inductor.fx_passes.post_grad.pass_patterns + [test_pass],
):
for fn in (f0, f1, f2):
counter = 0
expected = fn(*copy.deepcopy(args))
opt_fn = torch.compile(fn)
actual, (code) = run_and_get_code(opt_fn, args[0], args[1], args[2])
# pattern should match
self.assertEqual(counter, 1)
torch.testing.assert_close(actual, expected)
# addmm should be replaced
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_match_equivalent_function_invocations2(self):
counter = 0
test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
args = [
torch.randn(20, device="cuda"),
torch.randn(10, 15, device="cuda"),
torch.randn(15, 20, device="cuda"),
]
def f0(inp, a, b):
return torch.ops.aten.addmm(inp, a, b)
def f1(inp, a, b):
return torch.ops.aten.addmm(inp, a, b, beta=1.0)
def f2(inp, a, b):
return torch.ops.aten.addmm(inp, a, b, beta=1.0, alpha=1.0)
# This graph pattern should only match f0
@register_graph_pattern(
CallFunction(torch.ops.aten.addmm, Arg(), Arg(), Arg()),
pass_dict=test_pass,
)
def addmm_replacement(match: Match, inp, mat1, mat2):
nonlocal counter
counter += 1
def repl(inp, x1, x2):
return x1 @ x2 + inp
with V.fake_mode:
match.replace_by_example(repl, [inp, mat1, mat2])
with unittest.mock.patch(
"torch._inductor.fx_passes.post_grad.pass_patterns",
torch._inductor.fx_passes.post_grad.pass_patterns + [test_pass],
):
for fn in (f0, f1, f2):
counter = 0
expected = fn(*copy.deepcopy(args))
actual = torch.compile(fn)(*copy.deepcopy(args))
self.assertEqual(counter, 1)
torch.testing.assert_close(actual, expected)
def test_match_equivalent_function_invocations3(self):
counter = 0
test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
args = [
torch.randn(20, device="cuda"),
torch.randn(10, 15, device="cuda"),
torch.randn(15, 20, device="cuda"),
]
def f0(inp, a, b):
return torch.ops.aten.addmm(inp, a, b)
def f1(inp, a, b):
return torch.ops.aten.addmm(inp, a, b, beta=1.0)
def f2(inp, a, b):
return torch.ops.aten.addmm(inp, a, b, beta=1.0, alpha=1.0)
# This graph pattern should only match f1
@register_graph_pattern(
CallFunction(
torch.ops.aten.addmm, Arg(), Arg(), Arg(), beta=KeywordArg("beta")
),
pass_dict=test_pass,
)
def addmm_replacement(match: Match, inp, mat1, mat2, beta):
nonlocal counter
counter += 1
def repl(inp, x1, x2):
return x1 @ x2 + inp
with V.fake_mode:
match.replace_by_example(repl, [inp, mat1, mat2])
with unittest.mock.patch(
"torch._inductor.fx_passes.post_grad.pass_patterns",
torch._inductor.fx_passes.post_grad.pass_patterns + [test_pass],
):
for fn in (f0, f1, f2):
counter = 0
expected = fn(*copy.deepcopy(args))
actual = torch.compile(fn)(*copy.deepcopy(args))
self.assertEqual(counter, 1)
torch.testing.assert_close(actual, expected)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA:
run_tests()