# blob: c09df561de8f54f68c7569745389e1c3cc3d6ebf [file] [log] [blame]
# Owner(s): ["module: inductor"]
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
class TestPaternMatcher(TestCase):
def test_mm_plus_mm(self):
def fn(a, b, c, d):
return torch.add(torch.mm(a, b), torch.mm(c, d))
args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)
def test_addmm(self):
def fn(a, b, c):
return torch.add(a, torch.mm(b, c)), torch.mm(b, c) + a
args_list = [
(
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
),
(
torch.randn(16, 16, device="cuda"),
torch.randn(1, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
),
(
torch.randn(1, 16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
),
(4, torch.randn(16, 16, device="cuda"), torch.randn(16, 16, device="cuda")),
]
for args in args_list:
counters.clear()
e1, e2 = fn(*args)
a1, a2 = torch.compile(fn)(*args)
torch.testing.assert_close(a1, e1)
torch.testing.assert_close(a2, e2)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 2)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
def test_cat_mm(self):
def fn(a, b, c):
return torch.cat(
[
torch.mm(a, b),
torch.mm(b, c),
torch.mm(a, c),
],
1,
)
args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
def test_cat_addmm(self):
def fn(a, b, c):
return torch.cat(
[
torch.addmm(a, b, c),
torch.addmm(b, c, a),
torch.addmm(c, a, b),
],
1,
)
args = [
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
def test_cat_slice_cat(self):
def fn(a, b):
cat_1 = torch.ops.aten.cat.default([a, b], 1)
slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19)
return torch.ops.aten.cat.default([cat_1, slice_2], 1)
args = [
torch.randn(2, 32, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
counters.clear()
args = [
torch.randn(2, 8, device="cuda"),
torch.randn(2, 16, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
def test_splitwithsizes_cat(self):
# Good case
def fn(a):
split_with_sizes = torch.ops.aten.split_with_sizes.default(a, [8, 24], 1)
getitem = split_with_sizes[0]
getitem_1 = split_with_sizes[1]
cat = torch.ops.aten.cat.default([getitem, getitem_1], 1)
return cat**2
args = [
torch.randn(2, 32, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 4)
counters.clear()
# Not all getitems are passed to cat
def fn(a):
split_with_sizes = torch.ops.aten.split_with_sizes.default(a, [8, 8, 16], 1)
getitem = split_with_sizes[0]
getitem_1 = split_with_sizes[1]
getitem_2 = split_with_sizes[2]
cat = torch.ops.aten.cat.default([getitem, getitem_1], 1)
return cat**2 + getitem_2
args = [
torch.randn(2, 32, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
counters.clear()
# Different dimensions (TODO this case should be handled by replacing with a reshape)
def fn(a):
split_with_sizes = torch.ops.aten.split_with_sizes.default(
a, [8, 8, 8, 8], 1
)
cat = torch.ops.aten.cat.default(split_with_sizes, 0)
return cat**2
args = [
torch.randn(2, 32, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn)(*args)
torch.testing.assert_close(actual, expected)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 0)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 0)
if __name__ == "__main__":
    # The tests in this file create tensors with device="cuda", so the suite
    # is only launched on Linux hosts where CUDA is reported as available.
    if IS_LINUX and HAS_CUDA:
        run_tests()