# blob: b1686a1e973b8e37440583be07cb597419fad0f2 [file] [log] [blame]
# Owner(s): ["module: inductor"]
import torch
from torch import multiprocessing as mp
from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor import config
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import AlgorithmSelectorCache, ChoiceCaller
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
# Select the "high" float32 matmul precision setting for everything in this module.
torch.set_float32_matmul_precision("high")

# When CUDA is available, switch off expandable segments in the CUDA allocator.
# NOTE(review): presumably needed because CUDA tensors are handed to spawned
# subprocesses by the tests below -- confirm.
if HAS_CUDA:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")
def benchmark_choice(choice, args, out, expected_out, timings):
    """Benchmark ``choice`` and publish the measurement through ``timings``.

    Calls ``choice.benchmark(*args, out=out)``. When ``expected_out`` is not
    ``None``, ``out`` is compared against it with
    ``torch.testing.assert_close``. Finally the value returned by the
    benchmark call is copied in place into the ``timings`` tensor.
    """
    measured = choice.benchmark(*args, out=out)

    verify_output = expected_out is not None
    if verify_output:
        torch.testing.assert_close(out, expected_out)

    # Written in place so that whoever holds `timings` can observe the result.
    timings.copy_(torch.tensor(measured))
class FailChoiceCaller(ChoiceCaller):
    """A ``ChoiceCaller`` whose ``benchmark`` unconditionally raises."""

    def benchmark(self, *args, out):
        """Raise ``RuntimeError`` regardless of the arguments passed."""
        message = "This choice caller will always throw"
        raise RuntimeError(message)
@instantiate_parametrized_tests
class TestDoBench(TestCase):
    """Tests for benchmarking kernel choices and max-autotune, in-process and in subprocesses."""

    def _create_buffer(self, name, shape):
        """Return a ``Buffer`` called ``name`` with a fixed float32 layout of ``shape`` on ``cuda:0``."""
        return Buffer(name, FixedLayout(torch.device("cuda:0"), torch.float32, shape))

    def _make_dummy_graph(self):
        """Return a ``GraphLowering`` built from a trivial traced function.

        The traced function only creates a ``(2, 3)`` zeros tensor; the graph
        exists solely so it can be installed as the graph handler.
        """
        gm = make_fx(lambda: torch.zeros(2, 3))()
        return GraphLowering(gm)

    def _make_mm_plus_mm_inputs(self):
        """Create the buffers, output layout and example tensors for an mm_plus_mm choice.

        Must be called while a graph handler is installed (see the callers).

        Returns:
            A tuple ``(bufs, layout, mats, out)`` where ``bufs`` is the tuple
            of the four input buffers, ``layout`` is the ``(2, 2)`` output
            layout, ``mats`` is the tuple of example values for ``bufs`` (in
            the same order) and ``out`` is the example value for ``layout``.
        """
        bufs = (
            self._create_buffer("mat1", (2, 3)),
            self._create_buffer("mat2", (3, 2)),
            self._create_buffer("mat3", (2, 3)),
            self._create_buffer("mat4", (3, 2)),
        )
        layout = FixedLayout(torch.device("cuda:0"), torch.float32, (2, 2))
        mats = tuple(
            AlgorithmSelectorCache.benchmark_example_value(buf) for buf in bufs
        )
        out = AlgorithmSelectorCache.benchmark_example_value(layout)
        return bufs, layout, mats, out

    def _benchmark_in_subproc(self, choice, mats, out, expected_out):
        """Run ``benchmark_choice`` in a spawned child process and wait for it.

        Returns:
            A tuple ``(exitcode, timings)`` with the child's exit code and the
            tensor that was handed to the child for reporting the timing.
        """
        # use a tensor since the mutation to a python list in a sub process
        # is not synced back to the parent process
        timings = torch.zeros(3, dtype=torch.float32)
        ctx = mp.get_context("spawn")
        child = ctx.Process(
            target=benchmark_choice,
            args=(choice, mats, out, expected_out, timings),
        )
        child.start()
        child.join()
        return child.exitcode, timings

    def test_benchmark_choice_in_subproc(self):
        """Benchmarking the aten mm_plus_mm choice in a subprocess exits with code 0."""
        graph = self._make_dummy_graph()

        # the graph handler is needed to create benchmark example values below
        with V.set_graph_handler(graph):
            bufs, layout, mats, out = self._make_mm_plus_mm_inputs()

            # NOTE(review): no expected output is passed, so the child does not
            # verify `out`. The reference value would be
            # (mat1 @ mat2) + (mat3 @ mat4) -- confirm why the check is disabled.
            expected_out = None

            choice = aten_mm_plus_mm.bind(bufs, layout)
            exitcode, timings = self._benchmark_in_subproc(
                choice, mats, out, expected_out
            )
            self.assertEqual(0, exitcode)
            print(f"timings is {timings}, out {out}, expected_out {expected_out}")

    def test_benchmark_choice_fail_in_subproc(self):
        """A choice whose benchmark raises makes the subprocess exit with a non-zero code."""
        graph = self._make_dummy_graph()

        # the graph handler is needed to create benchmark example values below
        with V.set_graph_handler(graph):
            _, _, mats, out = self._make_mm_plus_mm_inputs()
            mat1, mat2, mat3, mat4 = mats
            expected_out = (mat1 @ mat2) + (mat3 @ mat4)

            choice = FailChoiceCaller("fail_choice_caller", [], None)
            exitcode, _ = self._benchmark_in_subproc(choice, mats, out, expected_out)
            self.assertNotEqual(0, exitcode)

    @parametrize("autotune_in_subproc", (True, False))
    def test_max_autotune_mm_plus_mm(self, autotune_in_subproc):
        """
        This used to crash due to a triton issue: https://github.com/openai/triton/issues/1298 .
        With autotuning in a subprocess, we don't crash anymore.
        """
        m, n, k = 2048, 1536, 64

        def mm_plus_mm(a, b, c, d):
            return a @ b + c @ d

        a = torch.randn(m, k).cuda()
        b = torch.randn(k, n).cuda()
        c = torch.randn(m, k).cuda()
        d = torch.randn(k, n).cuda()

        with config.patch(
            {"max_autotune": True, "autotune_in_subproc": autotune_in_subproc}
        ):
            torch.compile(mm_plus_mm)(a, b, c, d)

    def test_max_autotune_regular_mm(self):
        """
        Make sure autotuning mm in subprocesses works without crashes.
        """

        def mm(a, b):
            a = torch.sin(a)
            return a @ b

        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()

        with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
            torch.compile(mm)(a, b)

    def test_max_autotune_addmm(self):
        """
        Make sure autotuning addmm in subprocesses works without crashes.
        """

        def addmm(x, a, b):
            return torch.addmm(x, a, b)

        x = torch.randn(100).cuda()
        a = torch.randn(100, 10).cuda()
        b = torch.randn(10, 100).cuda()

        with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
            torch.compile(addmm)(x, a, b)
# Only run the suite when executed as a script and CUDA is available.
if __name__ == "__main__" and HAS_CUDA:
    run_tests()