# Owner(s): ["module: inductor"]
import copy
import unittest

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import expectedFailureDynamicWrapper
from torch._dynamo.utils import count_calls, counters
from torch._inductor.fx_passes import joint_graph
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA


class TestPatternMatcher(TestCase):
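    """Tests for inductor's FX pattern-matcher passes.

    Most tests compile a small function, compare the compiled output
    against eager mode, and then inspect the pattern-matcher counters:
    pattern_matcher_count is the number of pattern matches applied and
    pattern_matcher_nodes is the number of graph nodes covered by those
    matches.
    """
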
def test_mm_plus_mm(self):
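        """add(mm(a, b), mm(c, d)) should be rewritten into a single
        mm_plus_mm kernel: one match covering three nodes (two mms and
        the add), across several shapes including the ones from
        issue #100670."""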
def fn(a, b, c, d):
return torch.add(torch.mm(a, b), torch.mm(c, d))
args_list = [
(
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
),
# https://github.com/pytorch/pytorch/issues/100670.
(
torch.randn(1, 4, device="cuda"),
torch.randn(4, 2, device="cuda"),
torch.randn(1, 2, device="cuda"),
torch.randn(2, 1, device="cuda"),
),
(
torch.randn(1, 2, device="cuda"),
torch.randn(2, 1, device="cuda"),
torch.randn(1, 4, device="cuda"),
torch.randn(4, 2, device="cuda"),
),
(
torch.randn(1, 4, device="cuda"),
torch.randn(4, 2, device="cuda"),
torch.randn(1, 5, device="cuda"),
torch.randn(5, 2, device="cuda"),
),
]
for args in args_list:
counters.clear()
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)

    def test_addmm(self):
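        """add(a, mm(b, c)), with the add written in either order,
        should be rewritten to addmm: two matches over four nodes per
        compile."""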
def fn(a, b, c):
return torch.add(a, torch.mm(b, c)), torch.mm(b, c) + a
args_list = [
(
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
),
(
torch.randn(16, 16, device="cuda"),
torch.randn(1, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
),
(
torch.randn(1, 16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
),
(4, torch.randn(16, 16, device="cuda"), torch.randn(16, 16, device="cuda")),
]
for args in args_list:
torch._dynamo.reset()
counters.clear()
e1, e2 = fn(*args)
a1, a2 = torch.compile(fn)(*args)
torch.testing.assert_close(a1, e1)
torch.testing.assert_close(a2, e2)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)

    def test_addmm_activation(self):
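        """addmm followed by relu/gelu should fuse into
        _addmm_activation when the bias is rank-1; the replacement is
        skipped under max_autotune_gemm and for a rank-2 bias, which is
        not fusable."""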
def fn_addmm_relu(input, mat1, mat2):
return torch.nn.functional.relu(torch.addmm(input, mat1, mat2))
def fn_addmm_gelu(input, mat1, mat2):
return torch.nn.functional.gelu(torch.addmm(input, mat1, mat2))
args = [
torch.randn(20, device="cuda"), # input
torch.randn(10, 15, device="cuda"), # mat1
torch.randn(15, 20, device="cuda"), # mat2
]
for fn, atol in (
(fn_addmm_relu, 1e-8),
# higher tolerance due to the "tanh" approximation
# in fused GELU epilogue vs. "none" without fusion
(fn_addmm_gelu, 1e-3),
):
expected = fn(*args)
actual, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
            self.assertIn("_addmm_activation", code)
for fn in (fn_addmm_relu, fn_addmm_gelu):
counters.clear()
torch.compile(
fn,
# replacement disabled on max_autotune_gemm
options={"max_autotune_gemm": True},
)(*args)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
args_not_replaced = [
# addmm + activation with a rank-2 input
# is not fusable, hence not replaced
torch.randn(10, 20, device="cuda"), # input
torch.randn(10, 15, device="cuda"), # mat1
torch.randn(15, 20, device="cuda"), # mat2
]
for fn in (fn_addmm_relu, fn_addmm_gelu):
counters.clear()
torch.compile(fn)(*args_not_replaced)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)

    def test_cat_mm(self):
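        """torch.cat over several mm results should be matched by the
        cat_mm pattern; checks numerics and the match counters."""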
def fn(a, b, c):
return torch.cat(
[
torch.mm(a, b),
torch.mm(b, c),
torch.mm(a, c),
],
1,
)
args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 5)

    def test_cat_addmm(self):
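        """Same as test_cat_mm, but concatenating addmm results via the
        cat_addmm pattern."""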
def fn(a, b, c):
return torch.cat(
[
torch.addmm(a, b, c),
torch.addmm(b, c, a),
torch.addmm(c, a, b),
],
1,
)
args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 5)

    @expectedFailureDynamicWrapper
def test_cat_slice_cat(self):
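        """cat -> slice -> cat should be matched and simplified; a
        negative slice end (second fn below) still matches but takes
        the non-optimal fallback path."""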
def check_counter(counter, expected):
if not inductor_config.cpp_wrapper:
self.assertEqual(counter, expected)
else:
# cpp_wrapper for the CUDA backend runs two passes
self.assertEqual(counter, 2 * expected)
def fn(a, b):
cat_1 = torch.ops.aten.cat.default([a, b], 1)
slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
return torch.ops.aten.cat.default([cat_1, slice_2], 1)
args = [
torch.randn(2, 32, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
check_counter(counters["inductor"]["pattern_matcher_count"], 1)
check_counter(counters["inductor"]["pattern_matcher_nodes"], 4)
counters.clear()
args = [
torch.randn(2, 8, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
check_counter(counters["inductor"]["pattern_matcher_count"], 1)
check_counter(counters["inductor"]["pattern_matcher_nodes"], 4)
# Verify we fallback to non-optimal path for negative `end`.
def fn(a, b):
cat_1 = torch.ops.aten.cat.default([a, b], 1)
slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, -1)
return torch.ops.aten.cat.default([cat_1, slice_2], 1)
counters.clear()
args = [
torch.randn(2, 8, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
check_counter(counters["inductor"]["pattern_matcher_count"], 1)
check_counter(counters["inductor"]["pattern_matcher_nodes"], 4)

    def test_pointless_convert(self):
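        """A pair of back-to-back dtype conversions collapses to one
        when both casts are floating point; a chain through int32 must
        be kept, since the intermediate cast truncates values."""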
def fn1(x):
x = torch.ops.prims.convert_element_type.default(x, torch.float16)
x = torch.ops.prims.convert_element_type.default(x, torch.float32)
return x
gm = torch.fx.symbolic_trace(fn1)
self.assertEqual(count_calls(gm.graph), 2)
joint_graph.joint_graph_passes(gm)
self.assertEqual(count_calls(gm.graph), 1)
def fn2(x):
x = torch.ops.prims.convert_element_type.default(x, torch.int32)
x = torch.ops.prims.convert_element_type.default(x, torch.float32)
return x
gm = torch.fx.symbolic_trace(fn2)
self.assertEqual(count_calls(gm.graph), 2)
joint_graph.joint_graph_passes(gm)
self.assertEqual(count_calls(gm.graph), 2)

    def test_pointless_cumsum(self):
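        """cumsum over an all-ones tensor can be rewritten without a
        scan, so no aten.cumsum should appear in the generated code."""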
def fn1():
ones = torch.full(
[1, 128], 1, layout=torch.strided, dtype=torch.float32
).to(torch.int64)
return torch.cumsum(ones, 1) * ones
def fn2():
ones = torch.full(
[55, 10], 1, layout=torch.strided, dtype=torch.float32
).to(torch.int64)
return torch.cumsum(ones, 1)
for fn in (fn1, fn2):
result, (code,) = run_and_get_code(torch.compile(fn, fullgraph=True))
self.assertNotIn("aten.cumsum", code)
self.assertEqual(result, fn())
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
counters.clear()

    def test_splitwithsizes_cat(self):
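        """split_with_sizes followed by a cat that reassembles every
        piece, in order, along the same dim is a no-op and should
        match; partial use of the outputs, a different cat dim, or
        reordered pieces must not match."""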
# Good case
def fn(a):
split_with_sizes = torch.ops.aten.split_with_sizes.default(a, [8, 24], 1)
getitem = split_with_sizes[0]
getitem_1 = split_with_sizes[1]
cat = torch.ops.aten.cat.default([getitem, getitem_1], 1)
return cat**2
args = [
torch.randn(2, 32, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
counters.clear()
# Not all getitems are passed to cat
def fn(a):
split_with_sizes = torch.ops.aten.split_with_sizes.default(a, [8, 8, 16], 1)
getitem = split_with_sizes[0]
getitem_1 = split_with_sizes[1]
getitem_2 = split_with_sizes[2]
cat = torch.ops.aten.cat.default([getitem, getitem_1], 1)
return cat**2 + getitem_2
args = [
torch.randn(2, 32, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
counters.clear()
        # Different dimensions (TODO: this case should be handled by replacing with a reshape)
def fn(a):
split_with_sizes = torch.ops.aten.split_with_sizes.default(
a, [8, 8, 8, 8], 1
)
cat = torch.ops.aten.cat.default(split_with_sizes, 0)
return cat**2
args = [
torch.randn(2, 32, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
# https://github.com/pytorch/pytorch/issues/99686.
def fn(a):
x = torch.ops.aten.split_with_sizes.default(a, [3, 2, 3], dim=1)
cat = torch.ops.aten.cat.default([x[1], x[0], x[2]], dim=1)
return cat
args = [
torch.randn(1, 8, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)

    def test_match_with_mutation(self):
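        """A pass built with prevent_match_across_mutations=True should
        only match fn0; the mutations and other side-effectful ops that
        fn1-fn5 place between the sin and the add suppress the
        match."""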
from torch._inductor.pattern_matcher import (
CallFunction,
KeywordArg,
PatternMatcherPass,
register_graph_pattern,
)
counter = 0
test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
@register_graph_pattern(
CallFunction(
torch.add, KeywordArg("x"), CallFunction(torch.sin, KeywordArg("x"))
),
pass_dict=test_pass,
)
def _test(match, x):
nonlocal counter
counter += 1
def fn0(x, y):
a = torch.sin(x)
b = torch.add(x, a)
return b
def fn1(x, y):
a = torch.sin(x)
x.copy_(y)
b = torch.add(x, a)
return b
def fn2(x, y):
a = torch.sin(x)
with torch.no_grad():
b = torch.add(x, a)
return b
def fn3(x, y):
a = torch.sin(x)
with torch.autocast("cuda"):
b = torch.add(x, a)
return b
def fn4(x, y):
a = torch.sin(x)
torch.manual_seed(1234)
b = torch.add(x, a)
return b
def fn5(x, y):
a = torch.sin(x)
torch.add(y, 1, out=x)
b = torch.add(x, a)
return b
args = [
torch.randn(5, 5, device="cuda"),
torch.randn(5, 5, device="cuda"),
]
with unittest.mock.patch(
"torch._inductor.fx_passes.pre_grad.pattern_matcher_passes", [test_pass]
):
for fn in (fn0, fn1, fn2, fn3, fn4, fn5):
counter = 0
expected = fn(*copy.deepcopy(args))
actual = torch.compile(fn)(*copy.deepcopy(args))
                # only fn0, which has no mutation between the matched nodes, should match
                self.assertEqual(counter, int(fn is fn0))
torch.testing.assert_close(actual, expected)

    def test_remove_pointless_clones(self):
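        """A clone of a freshly computed mm result is pointless and
        should be removed: the generated code returns buf0 directly and
        never compiles a C++ kernel for the copy."""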
@torch.compile(fullgraph=True)
def fn(a, b):
return torch.mm(a, b).clone()
        result, code = run_and_get_code(fn, torch.randn(8, 8), torch.randn(8, 8))
        # without the pattern, the clone would materialize a second buffer (buf1)
self.assertIn("return (buf0, )", code[0])
self.assertNotIn("async_compile.cpp", code[0])


if __name__ == "__main__":
if IS_LINUX and HAS_CUDA:
run_tests()