| # Owner(s): ["module: dynamo"] |
| import functools |
| import unittest |
| from unittest.mock import patch |
| import torch |
| from torch._C import FileCheck |
| import torch._dynamo |
| import torch._dynamo.test_case |
| from torch._dynamo.utils import same |
| from torch._dynamo.testing import CompileCounter |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.testing._internal.common_distributed import ( |
| DynamoDistributedSingleProcTestCase, |
| DynamoDistributedMultiProcTestCase, |
| _dynamo_dist_per_rank_init, |
| requires_nccl, |
| skip_if_lt_x_gpu |
| ) |
| from torch._inductor.compile_fx import compile_fx as inductor_compile_fx |
| from torch._inductor.utils import has_triton, run_and_get_triton_code |
| import torch._dynamo.logging |
| |
# Note: this import is required. It registers the functional collective ops; without it,
# calls fall through to the no-op C++ kernel that exists only to satisfy registration.
| import torch.distributed._functional_collectives |
| |
| |
| @requires_nccl() |
| class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase): |
| """ |
    Run correctness checks in the multi-process runner; mark each test with the minimum
    number of GPUs it needs.
| """ |
| def get_world_trs(self): |
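        """
        Return the (tag, ranks, group_size) kwargs that the functional collective ops
        expect, describing the full world as the process group.
        """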
| return { |
| "tag": "", |
| "ranks": list(range(self.world_size)), |
| "group_size": self.world_size, |
| } |
| |
| @property |
| def world_size(self) -> int: |
        # Hack: always run with a world size of 2, regardless of whether 2, 3, or 4 GPUs
        # are available. This works around an issue with skip_if_lt_x_gpu(2) on workers
        # that have an unpredictable number of GPUs.
| return 2 |
| |
| @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") |
| @skip_if_lt_x_gpu(2) |
    # TODO: Inductor background compile threads appear to cause hangs at exit,
    # interacting with the distributed Work destructor.
| @patch.object(torch._inductor.config, "compile_threads", 1) |
| def test_allreduce_inductor(self): |
| """ |
        The matmul/cat/all_reduce sequence is a pattern we aim to optimize;
        verify that the Inductor-compiled version matches eager execution.
| """ |
| |
| def matmul_cat_col(a, b, c, d, e, f, *, tag, ranks, group_size): |
| x = torch.matmul(a, b) |
| y = torch.matmul(c, d) |
| z = torch.cat((x, y)) |
| ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size) |
| g = torch.matmul(e, f) |
| ar = torch.ops.c10d_functional.wait_tensor(ar) |
| out = torch.add(ar, g.repeat(2, 1)) |
| return (out, ) |
| |
| def compile(func, example_inputs): |
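            # Trace the function into an FX graph with make_fx, then lower that graph
            # through Inductor's compile_fx entry point.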
| graph = make_fx(func)(*example_inputs) |
| return inductor_compile_fx(graph, example_inputs) |
| |
| with _dynamo_dist_per_rank_init(self.rank, self.world_size): |
| |
| matmul_cat_col = functools.partial( |
| matmul_cat_col, |
| **self.get_world_trs(), |
| ) |
| inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6 |
| |
| eager_out = matmul_cat_col(*inputs) |
| compiled_matmul_cat_col = compile(matmul_cat_col, inputs) |
| inductor_out = compiled_matmul_cat_col(*inputs) |
| assert same(eager_out, inductor_out, tol=0.001) |
| |
| @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") |
| @skip_if_lt_x_gpu(2) |
    # TODO: Inductor background compile threads appear to cause hangs at exit,
    # interacting with the distributed Work destructor.
| @patch.object(torch._inductor.config, "compile_threads", 1) |
| def test_eager_allreduce_inductor_wait(self): |
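        """
        Issue the all_reduce eagerly, then compile only the graph containing the
        wait_tensor with Inductor, and check that the result matches full eager execution.
        """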
| |
| def eager_func(a, b, c, d, *, tag, ranks, group_size): |
| x = torch.matmul(a, b) |
| y = torch.matmul(c, d) |
| z = torch.cat((x, y)) |
| ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size) |
| return ar |
| |
| def inductor_func(ar, e, f): |
| g = torch.matmul(e, f) |
| ar = torch.ops.c10d_functional.wait_tensor(ar) |
| out = torch.add(ar, g.repeat(2, 1)) |
| return (out, ) |
| |
| def compile(func, example_inputs): |
| graph = make_fx(func)(*example_inputs) |
| return inductor_compile_fx(graph, example_inputs) |
| |
| with _dynamo_dist_per_rank_init(self.rank, self.world_size): |
| |
| eager_func = functools.partial( |
| eager_func, |
| **self.get_world_trs(), |
| ) |
| eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4 |
| inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2 |
| |
| eager_out = inductor_func(eager_func(*eager_inputs), *inductor_inputs) |
| compiled_inductor_func = compile(inductor_func, [eager_func(*eager_inputs)] + list(inductor_inputs)) |
| inductor_out = compiled_inductor_func(eager_func(*eager_inputs), *inductor_inputs) |
| print(f"eager_out, {eager_out}") |
| print(f"inductor_out, {inductor_out}") |
| assert same(eager_out, inductor_out, tol=0.001) |
| |
| @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") |
| @skip_if_lt_x_gpu(2) |
    # TODO: Inductor background compile threads appear to cause hangs at exit,
    # interacting with the distributed Work destructor.
| @patch.object(torch._inductor.config, "compile_threads", 1) |
| def test_inductor_allreduce_eager_wait(self): |
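        """
        Compile the graph that issues the all_reduce with Inductor, then wait on the
        result eagerly, and check that the result matches full eager execution.
        """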
| |
| def inductor_func(a, b, c, d, *, tag, ranks, group_size): |
| x = torch.matmul(a, b) |
| y = torch.matmul(c, d) |
| z = torch.cat((x, y)) |
| ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size) |
| return ar |
| |
| def eager_func(ar, e, f): |
| g = torch.matmul(e, f) |
| ar = torch.ops.c10d_functional.wait_tensor(ar) |
| out = torch.add(ar, g.repeat(2, 1)) |
| return (out, ) |
| |
| def compile(func, example_inputs): |
| graph = make_fx(func)(*example_inputs) |
| return inductor_compile_fx(graph, example_inputs) |
| |
| with _dynamo_dist_per_rank_init(self.rank, self.world_size): |
| |
| inductor_func = functools.partial( |
| inductor_func, |
| **self.get_world_trs(), |
| ) |
| inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4 |
| eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2 |
| |
| eager_out = eager_func(inductor_func(*inductor_inputs), *eager_inputs) |
| compiled_inductor_func = compile(inductor_func, inductor_inputs) |
| inductor_out = eager_func(compiled_inductor_func(*inductor_inputs), *eager_inputs) |
| assert same(eager_out, inductor_out, tol=0.001) |
| |
| @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") |
| @skip_if_lt_x_gpu(2) |
    # TODO: Inductor background compile threads appear to cause hangs at exit,
    # interacting with the distributed Work destructor.
| @patch.object(torch._inductor.config, "compile_threads", 1) |
| def test_allgather_into_tensor_inductor(self): |
| """ |
        Verify that a matmul followed by all_gather_into_tensor compiles with Inductor
        and matches eager execution.
| """ |
| |
| def example(a, b, *, tag, ranks, group_size): |
| c = torch.matmul(a, b) |
| ag = torch.ops.c10d_functional.all_gather_into_tensor(c, tag, ranks, group_size) |
| ag = torch.ops.c10d_functional.wait_tensor(ag) |
| return (ag, ) |
| |
| def compile(func, example_inputs): |
| graph = make_fx(func)(*example_inputs) |
| return inductor_compile_fx(graph, example_inputs) |
| |
| with _dynamo_dist_per_rank_init(self.rank, self.world_size): |
| |
| example = functools.partial( |
| example, |
| **self.get_world_trs(), |
| ) |
| inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2 |
| |
| eager_out = example(*inputs) |
| compiled_matmul_cat_col = compile(example, inputs) |
| inductor_out = compiled_matmul_cat_col(*inputs) |
| assert same(eager_out, inductor_out, tol=0.001) |
| |
| @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") |
| @skip_if_lt_x_gpu(2) |
    # TODO: Inductor background compile threads appear to cause hangs at exit,
    # interacting with the distributed Work destructor.
| @patch.object(torch._inductor.config, "compile_threads", 1) |
| def test_reduce_scatter_tensor_inductor(self): |
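        """
        Verify that a matmul followed by reduce_scatter_tensor compiles with Inductor
        and matches eager execution.
        """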
| def example(a, b, *, tag, ranks, group_size): |
| c = torch.matmul(a, b) |
| ag = torch.ops.c10d_functional.reduce_scatter_tensor( |
| c, "sum", tag, ranks, group_size |
| ) |
| ag = torch.ops.c10d_functional.wait_tensor(ag) |
| return (ag,) |
| |
| def compile(func, example_inputs): |
| graph = make_fx(func)(*example_inputs) |
| return inductor_compile_fx(graph, example_inputs) |
| |
| with _dynamo_dist_per_rank_init(self.rank, self.world_size): |
| example = functools.partial( |
| example, |
| **self.get_world_trs(), |
| ) |
| inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2 |
| |
| eager_out = example(*inputs) |
| compiled_fn = compile(example, inputs) |
| inductor_out = compiled_fn(*inputs) |
| assert same(eager_out, inductor_out, tol=0.001) |
| |
| |
| @requires_nccl() |
| class TestCollectivesInductor(DynamoDistributedSingleProcTestCase): |
| """ |
| Prefer single-proc test runner for basic tests as it is easier to work with. |
| """ |
| def get_world_trs(self, world_size=1): |
| return { |
| "tag": "", |
| "ranks": list(range(world_size)), |
| "group_size": world_size, |
| } |
| |
| @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") |
| def test_inductor_single_op(self): |
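        """
        Compile a lone all_reduce + wait_tensor and inspect the generated wrapper code:
        the input is copied into a fresh buffer, dist.all_reduce is issued on it, the
        in-flight work is registered, and the buffer is waited on before being returned.
        """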
| torch._inductor.config.debug = True |
| |
| def func(inp, *, tag, ranks, group_size): |
| ar = torch.ops.c10d_functional.all_reduce(inp, "sum", tag, ranks, group_size) |
| ar = torch.ops.c10d_functional.wait_tensor(ar) |
| return ar |
| |
| inputs = torch.ones(4, 4, device="cuda") |
| |
| compiled = torch.compile(func) |
| out = compiled(inputs, **self.get_world_trs()) |
| code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) |
| FileCheck() \ |
| .check("buf0 = empty_strided") \ |
| .check("buf0.copy_(arg0_1)") \ |
| .check("buf0_work = dist.all_reduce(buf0") \ |
| .check("_register_tensor_work(buf0, buf0_work)") \ |
| .check("_wait_tensor(buf0)") \ |
| .check("return (buf1, )") \ |
| .run(code) |
| correct = func(inputs, **self.get_world_trs()) |
| assert same(out, correct) |
| |
| @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") |
| def test_inductor_steal_buffer(self): |
| """ |
        It is correct (and optimal) for the Inductor all_reduce to mutate the buffer of an
        intermediate that is not used again, instead of copying it first.
| """ |
| torch._inductor.config.debug = True |
| |
| def func(inp, *, tag, ranks, group_size): |
| x = inp + 1 |
| ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size) |
| ar = torch.ops.c10d_functional.wait_tensor(ar) |
| # ensure other is not incorrectly aliasing ar's buffer |
| other = torch.ones_like(inp) + 22 |
| return ar, other |
| |
| inputs = torch.ones(4, 4, device="cuda") |
| |
| compiled = torch.compile(func) |
| code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) |
| FileCheck() \ |
| .check("buf1 = buf0; del buf0 # reuse") \ |
| .check_not("buf1.copy_(") \ |
| .check("buf1_work = dist.all_reduce(buf1") \ |
| .check("_register_tensor_work(buf1, buf1_work)") \ |
| .check("_wait_tensor(buf1)") \ |
| .check("buf2 = buf1") \ |
| .check("buf3 = empty_strided") \ |
| .check("return (buf2, buf3") \ |
| .run(code) |
| out = compiled(inputs, **self.get_world_trs()) |
| correct = func(inputs, **self.get_world_trs()) |
| assert same(out, correct) |
| |
| @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") |
| @patch.object(torch._inductor.config.triton, "descriptive_names", False) |
| def test_inductor_doesnt_mutate_shared(self): |
| """ |
        Make sure an intermediate that is reused later is not mutated by the collective
        unless it is first copied.
| """ |
| torch._inductor.config.debug = True |
| |
| def func(inp, *, tag, ranks, group_size): |
| x = inp + 1 |
| ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size) |
| y = x + 2 |
| ar = torch.ops.c10d_functional.wait_tensor(ar) |
| # ensure other is not incorrectly aliasing ar's buffer |
| other = torch.ones_like(inp) + 22 |
| return ar, y, other |
| |
| inputs = torch.ones(4, 4, device="cuda") |
| |
| compiled = torch.compile(func) |
| code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs()) |
| FileCheck() \ |
| .check("buf0 = empty_strided(") \ |
| .check("buf2 = empty_strided") \ |
| .check("triton_poi__0.run(arg0_1, buf0, buf2") \ |
| .check_not("copy_(") \ |
| .check("buf1 = buf0; del buf0 # reuse") \ |
| .check("buf1_work = dist.all_reduce(buf1") \ |
| .check("_register_tensor_work(buf1, buf1_work)") \ |
| .check("_wait_tensor(buf1)") \ |
| .check("buf3 = buf1") \ |
| .check("return (buf3, buf2, buf4") \ |
| .run(code) |
| out = compiled(inputs, **self.get_world_trs()) |
| correct = func(inputs, **self.get_world_trs()) |
| assert same(out, correct) |
| |
| def test_dynamo_trace_allreduce(self): |
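        """
        Check that dynamo traces a functional all_reduce into a single graph
        (one frame, one op) and that the result matches eager execution.
        """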
| def func(inp, *, tag, ranks, group_size): |
| ar = torch.ops.c10d_functional.all_reduce(inp, "sum", tag, ranks, group_size) |
| return ar |
| |
| inputs = torch.ones(4, 4, device="cuda") |
| counter = CompileCounter() |
| compiled = torch.compile(func, backend=counter) |
| out = compiled(inputs, **self.get_world_trs()) |
| correct = func(inputs, **self.get_world_trs()) |
| assert counter.frame_count == 1 |
| assert counter.op_count == 1 |
| assert same(out, correct) |
| |
| def test_backwards(self): |
| """ |
        Backwards support for collectives is probably not commonly needed.

        However, this checks whether it is at least possible to support, as a design goal.
| """ |
| def func(inp, *, tag, ranks, group_size): |
| ar = torch.ops.c10d_functional.all_reduce(inp, "sum", tag, ranks, group_size) |
| return ar |
| |
| input = torch.ones(4, 4, device="cuda", requires_grad=True) |
| # TODO implement backwards |
| with self.assertRaisesRegex(RuntimeError, "element 0 of tensors does not require grad and does not have a grad_fn"): |
| compiled = torch.compile(func, backend="aot_eager") # inductor bug with single-op allreduce graph |
| out = compiled(input, **self.get_world_trs()) |
| out.sum().backward() |
| |
| correct_input = input.clone().detach().requires_grad_() |
| correct = func(correct_input, **self.get_world_trs()) |
| correct.sum().backward() |
| assert same(out, correct) |
| assert same(input.grad, correct_input.grad) |
| |
| def test_meta(self): |
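        """
        The functional all_reduce should work on meta tensors and preserve the input shape.
        """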
| x = torch.rand((2, 3, 4), device="meta") |
| out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs()) |
| assert x.size() == out.size() |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |