# Owner(s): ["module: inductor"]
import unittest

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import rand_strided
from torch._inductor.fx_passes.pad_mm import get_alignment_size, get_padded_length
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.inductor_utils import HAS_CUDA
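

# These tests exercise inductor's pad_mm FX pass under dynamic shapes: dims
# marked dynamic via torch._dynamo.mark_dynamic are expected to be left
# unpadded, while static, unaligned dims should be padded to the length
# computed from get_padded_length/get_alignment_size. Most tests FileCheck
# the generated Triton code for the expected dimension value.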
class PadMMTest(TestCase):

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_mm_dyn_m(self):
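        # M is dynamic; the narrowed K dim (K2) is static and unaligned, so
        # the generated code is expected to use the padded K value.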
        M = 40
        K1 = 581
        K2 = 49
        N = 30

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = rand_strided(
                    (K2, N), (1, K2), device="cuda", dtype=torch.float32
                )

            def forward(self, a):
                a1 = torch.narrow(a, 1, 0, K2)
                return torch.mm(a1, self.w)

        fn = Model().cuda()
        a = rand_strided((M, K1), (K1, 1), device="cuda", dtype=torch.float32)
        aligned_k = get_padded_length(K2, get_alignment_size(a)) + K2
        torch._dynamo.mark_dynamic(a, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a)
            FileCheck().check(f"K = {aligned_k}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_cat_pad_mm_dyn_m(self):
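        # Both cat inputs have a dynamic first dim, so the mm's M is dynamic;
        # the narrowed K dim (K2) is static and should still be padded.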
        M1 = 128
        M2 = 40
        K1 = 129
        K2 = 111
        N = 100

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = rand_strided(
                    (K2, N), (1, K2), device="cuda", dtype=torch.float32
                )

            def forward(self, a, b):
                c = torch.cat([a, b], dim=0)
                a1 = torch.narrow(c, 1, 0, K2)
                return torch.mm(a1, self.w)

        fn = Model().cuda()
        a = rand_strided((M1, K1), (K1, 1), device="cuda", dtype=torch.float32)
        b = rand_strided((M2, K1), (K1, 1), device="cuda", dtype=torch.float32)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        aligned_k = get_padded_length(K2, get_alignment_size(a)) + K2
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"K = {aligned_k}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_mm_dyn_n(self):
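        # N is dynamic; K is static and unaligned, so the generated code is
        # expected to use the padded K value.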
        M = 20
        K = 81
        N = 30

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b):
                return torch.mm(a, b)

        fn = Model().cuda()
        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
        aligned_k = get_padded_length(K, get_alignment_size(a)) + K
        torch._dynamo.mark_dynamic(b, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"K = {aligned_k}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_mm_dyn_k(self):
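        # K is dynamic; M is static and unaligned, so padding is expected on M
        # and the generated code should use the padded M value.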
        M = 21
        K = 80
        N = 30

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b):
                return torch.mm(a, b)

        fn = Model().cuda()
        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
        aligned_m = get_padded_length(M, get_alignment_size(a)) + M
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"M = {aligned_m}").run(code)
            self.assertEqual(res1, res2)

    def test_pad_mm_dyn_mnk(self):
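        # All of M, N and K are marked dynamic; this test only checks that the
        # compiled result matches eager (no FileCheck on the generated code).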
        M = 20
        K = 81
        N = 30

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b):
                return torch.mm(a, b)

        fn = Model().cuda()
        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        torch._dynamo.mark_dynamic(b, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_bmm_dyn_b(self):
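        # The batch dim is dynamic; K is static and unaligned, so the bmm is
        # expected to be padded along K.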
        B = 10
        M = 128
        K = 33
        N = 40

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        fn = Model().cuda()
        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
        aligned_k = get_padded_length(K, get_alignment_size(a)) + K
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"K = {aligned_k}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_bmm_dyn_k(self):
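        # K is dynamic; N is static and unaligned, so padding is expected on N
        # and the generated code should use the padded N value.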
        B = 10
        M = 128
        K = 40
        N = 41

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        fn = Model().cuda()
        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
        aligned_n = get_padded_length(N, get_alignment_size(b)) + N
        torch._dynamo.mark_dynamic(a, 2)
        torch._dynamo.mark_dynamic(b, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"N = {aligned_n}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_bmm_dyn_bm(self):
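        # Both the batch and M dims are dynamic; N is static and unaligned, so
        # the bmm is still expected to be padded along N.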
        B = 10
        M = 128
        K = 40
        N = 41

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        fn = Model().cuda()
        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
        aligned_n = get_padded_length(N, get_alignment_size(b)) + N
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"N = {aligned_n}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_addmm_dyn_m(self):
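        # M is dynamic; K is static and unaligned, so the addmm is expected to
        # be padded along K.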
        M = 128
        K = 33
        N = 40

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b, c):
                return torch.addmm(a, b, c)

        fn = Model().cuda()
        a = torch.randn(M, N, device="cuda", dtype=torch.float32)
        b = torch.randn(M, K, device="cuda", dtype=torch.float32)
        c = torch.randn(K, N, device="cuda", dtype=torch.float32)
        aligned_k = get_padded_length(K, get_alignment_size(b)) + K
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b, c)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b, c)
            FileCheck().check(f"K = {aligned_k}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_addmm_dyn_mn(self):
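        # M and N are dynamic; no padding is expected, so the generated code
        # should show the original (unpadded) K value.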
        M = 128
        K = 33
        N = 40

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b, c):
                return torch.addmm(a, b, c)

        fn = Model().cuda()
        a = torch.randn(M, N, device="cuda", dtype=torch.float32)
        b = torch.randn(M, K, device="cuda", dtype=torch.float32)
        c = torch.randn(K, N, device="cuda", dtype=torch.float32)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        torch._dynamo.mark_dynamic(c, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b, c)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b, c)
            # no padding
            FileCheck().check(f"K = {K}").run(code)
            self.assertEqual(res1, res2)


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()