# Owner(s): ["module: inductor"]
import torch
from torch import multiprocessing as mp
from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor import config
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import AlgorithmSelectorCache, ChoiceCaller
from torch._inductor.utils import run_and_get_code
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_CUDA

torch.set_float32_matmul_precision("high")

if HAS_CUDA:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")
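

# Helper run in a spawned subprocess: benchmark a single ChoiceCaller and copy
# the resulting timings into a tensor pre-allocated by the parent process.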
def benchmark_choice(choice, args, out, expected_out, timings):
    result = choice.benchmark(*args, out=out)
    if expected_out is not None:
        torch.testing.assert_close(out, expected_out)
    timings.copy_(torch.tensor(result))
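

# A ChoiceCaller whose benchmark always raises; used to check that a failing
# choice makes the benchmarking subprocess exit with a nonzero code.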
class FailChoiceCaller(ChoiceCaller):
    def benchmark(self, *args, out):
        raise RuntimeError("This choice caller will always throw")
@instantiate_parametrized_tests
class TestDoBench(TestCase):
    def _create_buffer(self, name, shape):
        return Buffer(name, FixedLayout(torch.device("cuda:0"), torch.float32, shape))

    def test_benchmark_choice_in_subproc(self):
        gm = make_fx(
            lambda: torch.zeros(2, 3)
        )()  # a dummy graph to construct the GraphLowering
        graph = GraphLowering(gm)
        # the graph handler is needed to create the benchmark example values below
        with V.set_graph_handler(graph):
            buf1 = self._create_buffer("mat1", (2, 3))
            buf2 = self._create_buffer("mat2", (3, 2))
            buf3 = self._create_buffer("mat3", (2, 3))
            buf4 = self._create_buffer("mat4", (3, 2))
            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))
            mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
            mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
            mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
            mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)
            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            # expected_out = (mat1 @ mat2) + (mat3 @ mat4)
            expected_out = None
            choice = aten_mm_plus_mm.bind((buf1, buf2, buf3, buf4), layout)
            # use a tensor since a mutation to a python list in a subprocess
            # is not synced back to the parent process
            timings = torch.zeros(3, dtype=torch.float32)
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
            self.assertEqual(0, child.exitcode)
            print(f"timings is {timings}, out {out}, expected_out {expected_out}")

    def test_benchmark_choice_fail_in_subproc(self):
        gm = make_fx(
            lambda: torch.zeros(2, 3)
        )()  # a dummy graph to construct the GraphLowering
        graph = GraphLowering(gm)
        # the graph handler is needed to create the benchmark example values below
        with V.set_graph_handler(graph):
            buf1 = self._create_buffer("mat1", (2, 3))
            buf2 = self._create_buffer("mat2", (3, 2))
            buf3 = self._create_buffer("mat3", (2, 3))
            buf4 = self._create_buffer("mat4", (3, 2))
            layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))
            mat1 = AlgorithmSelectorCache.benchmark_example_value(buf1)
            mat2 = AlgorithmSelectorCache.benchmark_example_value(buf2)
            mat3 = AlgorithmSelectorCache.benchmark_example_value(buf3)
            mat4 = AlgorithmSelectorCache.benchmark_example_value(buf4)
            out = AlgorithmSelectorCache.benchmark_example_value(layout)
            expected_out = (mat1 @ mat2) + (mat3 @ mat4)
            choice = FailChoiceCaller("fail_choice_caller", [], None)
            # use a tensor since a python list is not synced back from the subprocess
            timings = torch.zeros(3, dtype=torch.float32)
            ctx = mp.get_context("spawn")
            child = ctx.Process(
                target=benchmark_choice,
                args=(choice, (mat1, mat2, mat3, mat4), out, expected_out, timings),
            )
            child.start()
            child.join()
            self.assertNotEqual(0, child.exitcode)
@parametrize("autotune_in_subproc", (True, False))
def test_max_autotune_mm_plus_mm(self, autotune_in_subproc):
"""
This crash previously due to a triton issue: https://github.com/openai/triton/issues/1298 .
With autotuning in subprocess, we don't crash anymore.
"""
m, n, k = 2048, 1536, 64
def mm_plus_mm(a, b, c, d):
return a @ b + c @ d
a = torch.randn(m, k).cuda()
b = torch.randn(k, n).cuda()
c = torch.randn(m, k).cuda()
d = torch.randn(k, n).cuda()
with config.patch(
{"max_autotune": True, "autotune_in_subproc": autotune_in_subproc}
):
torch.compile(mm_plus_mm)(a, b, c, d)
@parametrize("dynamic", (False, True))
def test_max_autotune_mm_plus_mm_zero_size_input(self, dynamic):
"""
Make sure autotuning mm_plus_mm with zero-size input works without crashes.
"""
m, n, k = 0, 1536, 64
def mm_plus_mm(a, b, c, d):
return a @ b + c @ d
a = torch.randn(m, k).cuda()
b = torch.randn(k, n).cuda()
c = torch.randn(m, k).cuda()
d = torch.randn(k, n).cuda()
with config.patch({"max_autotune": True}):
torch.compile(mm_plus_mm, dynamic=dynamic)(a, b, c, d)
@parametrize("dynamic", (False, True))
def test_max_autotune_regular_mm(self, dynamic: bool):
"""
Make sure autotuning mm in sub processes work without crashes.
"""
def mm(a, b):
a = torch.sin(a)
return a @ b
a = torch.randn(100, 10).cuda()
b = torch.randn(10, 100).cuda()
with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
torch.compile(mm, dynamic=dynamic)(a, b)
@parametrize("dynamic", (False, True))
def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool):
"""
Make sure autotuning mm with zero-size input works without crashes.
"""
def mm(a, b):
a = torch.sin(a)
return a @ b
a = torch.randn(0, 10).cuda()
b = torch.randn(10, 100).cuda()
with config.patch({"max_autotune": True}):
torch.compile(mm, dynamic=dynamic)(a, b)
@parametrize("dynamic", (False, True))
def test_max_autotune_addmm(self, dynamic):
"""
Make sure autotuning addmm in sub processes work without crashes.
"""
def addmm(x, a, b):
return torch.addmm(x, a, b)
x = torch.randn(100).cuda()
a = torch.randn(100, 10).cuda()
b = torch.randn(10, 100).cuda()
with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
torch.compile(addmm, dynamic=dynamic)(x, a, b)
@parametrize("dynamic", (False, True))
def test_max_autotune_addmm_zero_size_input(self, dynamic):
"""
Make sure autotuning addmm with zero-size input works without crashes.
"""
def addmm(x, a, b):
return torch.addmm(x, a, b)
x = torch.randn(100).cuda()
a = torch.randn(0, 10).cuda()
b = torch.randn(10, 100).cuda()
with config.patch({"max_autotune": True}):
torch.compile(addmm, dynamic=dynamic)(x, a, b)

    @skipIfRocm
    def test_autotune_conv1x1(self):
        # Define the 1x1 convolutional layer
        # Assuming input has 3 channels and we want to produce 16 channels as output
        conv1x1 = (
            torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1)
            .to(memory_format=torch.channels_last)
            .cuda()
        )
        # Example input tensor: batch size = 4, channels = 3, height = 32, width = 32
        # The memory format is set to `channels_last`
        input_tensor = (
            torch.randn(4, 3, 32, 32)
            .contiguous(memory_format=torch.channels_last)
            .cuda()
        )

        with config.patch(
            {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
        ):

            @torch.compile()
            def foo(mod, x):
                return mod(x)

            with torch.no_grad():
                out, code = run_and_get_code(foo, conv1x1, input_tensor)

            FileCheck().check_not("extern_kernels.convolution").run(code[0])
            self.assertEqual(conv1x1(input_tensor), out, atol=1e-2, rtol=0)

    def test_cat_addmm(self):
        def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
            return torch.cat(
                [
                    torch.addmm(a, b, c),
                    torch.addmm(b, c, a),
                ],
                1,
            )

        args = [
            torch.randn(4, 4, device="cuda"),
            torch.randn(4, 4, device="cuda"),
            torch.randn(4, 4, device="cuda"),
        ]
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "Triton",
            }
        ):
            expected = fn(*args)
            actual = torch.compile(fn)(*args)
            torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)

    def test_triton_template_with_epilogues_and_dynamic_shape(self):
        def fn(
            x: torch.Tensor, w: torch.Tensor, bias: torch.Tensor, mul: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.nn.functional.relu(
                    torch.matmul(torch.transpose(x, 0, 1), torch.transpose(w, 0, 1))
                    + bias
                )
                * mul
            )

        M0 = 5
        M1 = 8
        K = 4
        N = 3
        w = torch.rand(N, K).cuda().half()
        b = torch.rand(N).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": "Triton",
            }
        ):
            compiled_fn = torch.compile(
                fn, fullgraph=True, dynamic=True, mode="max-autotune-no-cudagraphs"
            )

            x0 = torch.rand(K, M0).cuda().half()
            mul0 = torch.rand(M0, N).cuda().half()
            y0 = compiled_fn(x0, w, b, mul0)
            y0_expected = fn(x0, w, b, mul0)
            torch.testing.assert_close(y0, y0_expected)

            x1 = torch.rand(K, M1).cuda().half()
            mul1 = torch.rand(M1, N).cuda().half()
            y1 = compiled_fn(x1, w, b, mul1)
            y1_expected = fn(x1, w, b, mul1)
            torch.testing.assert_close(y1, y1_expected)
if __name__ == "__main__":
if HAS_CUDA:
run_tests()