blob: af59f6188581d8f42359bf7f2e924f21cb33516b [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import functools
import unittest
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.testing import CompileCounter
from torch._dynamo.utils import same
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed.distributed_c10d import GroupMember
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_distributed import (
_dynamo_dist_per_rank_init,
DynamoDistributedMultiProcTestCase,
DynamoDistributedSingleProcTestCase,
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
requires_cuda,
)
from torch.utils._triton import has_triton
def _tolist_with_constrain_as_size(tensor):
lst = tensor.tolist()
for elem in lst:
torch._check_is_size(elem)
return lst
@requires_nccl()
class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
    """

    def get_world_trs(self):
        """Return the ``tag``/``ranks``/``group_size`` kwargs covering every rank."""
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
        # works around issue with skipif<2 and workers with unpredictable #s gpu
        return 2

    @staticmethod
    def _compile_with_inductor(func, example_inputs):
        """Trace ``func`` with ``make_fx`` and compile the graph with inductor.

        Shared by the tests below so that each one does not need its own local
        helper (which previously shadowed the builtin ``compile``).
        """
        graph = make_fx(func)(*example_inputs)
        return inductor_compile_fx(graph, example_inputs)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_broadcast_inductor(self):
        """
        Testing if broadcast works correctly when using inductor
        """

        def example(tensor, src, *, tag, ranks, group_size):
            res = torch.ops.c10d_functional.broadcast(
                tensor, src, tag, ranks, group_size
            )
            res = torch.ops.c10d_functional.wait_tensor(res)
            return res

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            t = torch.randn(4, 4, device="cuda")
            # Only rank 0 feeds the random tensor; other ranks pass zeros.
            inputs = (t if self.rank == 0 else torch.zeros(4, 4, device="cuda"), 0)
            eager_out = example(*inputs)
            self.assertTrue(same(t, eager_out))

            compiled_func = self._compile_with_inductor(example, inputs)
            compiled_out = compiled_func(*inputs)
            self.assertTrue(same(eager_out, compiled_out))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allreduce_inductor(self):
        """
        This matmul/cat/allreduce is a pattern we aim to optimize.
        """

        def matmul_cat_col(a, b, c, d, e, f, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            # This matmul is issued between the all_reduce and its wait.
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            matmul_cat_col = functools.partial(
                matmul_cat_col,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 6

            eager_out = matmul_cat_col(*inputs)
            compiled_matmul_cat_col = self._compile_with_inductor(
                matmul_cat_col, inputs
            )
            inductor_out = compiled_matmul_cat_col(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allreduce_inductor_cudagraph_trees(self):
        """
        Tests whether cudagraph trees support all_reduce from nccl
        """
        import torch.distributed as dist

        # dist.all_reduce is an inplace op in eager mode but a functionalized op in compiled mode.
        # so we define eager_func and func separately for the same semantic.
        def eager_func(x):
            y = x * x
            dist.all_reduce(y, op=dist.ReduceOp.SUM)
            x = torch.nn.functional.silu(x)
            return x * y

        def func(x):
            y = x * x
            y = dist.all_reduce(y, op=dist.ReduceOp.SUM)
            x = torch.nn.functional.silu(x)
            return x * y

        options = {
            "triton.cudagraphs": True,
            "triton.cudagraph_trees": True,
        }

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            compiled_func = torch.compile(
                func, backend="inductor", fullgraph=True, options=options, dynamic=None
            )

            for nelem in [1024, 2048, 4096]:
                x = torch.randn(nelem, device="cuda", dtype=torch.bfloat16)
                golden_out = eager_func(x)

                # Call the compiled function several times per input size.
                for _ in range(3):
                    compiled_out = compiled_func(x)
                    self.assertEqual(golden_out, compiled_out)

    def test_c10d_functional_tagged_pt2_compliant(self):
        """Both all_reduce op namespaces must carry the ``pt2_compliant_tag``."""
        op = torch.ops._c10d_functional.all_reduce.default
        self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
        op = torch.ops.c10d_functional.all_reduce.default
        self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_eager_allreduce_inductor_wait(self):
        """all_reduce is issued in eager; the wait_tensor runs in inductor-compiled code."""

        def eager_func(a, b, c, d, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def inductor_func(ar, e, f):
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            eager_func = functools.partial(
                eager_func,
                **self.get_world_trs(),
            )
            eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
            inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = inductor_func(eager_func(*eager_inputs), *inductor_inputs)
            compiled_inductor_func = self._compile_with_inductor(
                inductor_func, [eager_func(*eager_inputs)] + list(inductor_inputs)
            )
            inductor_out = compiled_inductor_func(
                eager_func(*eager_inputs), *inductor_inputs
            )
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_inductor_allreduce_eager_wait(self):
        """all_reduce is issued from inductor-compiled code; the wait_tensor runs in eager."""

        def inductor_func(a, b, c, d, *, tag, ranks, group_size):
            x = torch.matmul(a, b)
            y = torch.matmul(c, d)
            z = torch.cat((x, y))
            ar = torch.ops.c10d_functional.all_reduce(z, "sum", tag, ranks, group_size)
            return ar

        def eager_func(ar, e, f):
            g = torch.matmul(e, f)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            out = torch.add(ar, g.repeat(2, 1))
            return (out,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inductor_func = functools.partial(
                inductor_func,
                **self.get_world_trs(),
            )
            inductor_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 4
            eager_inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = eager_func(inductor_func(*inductor_inputs), *eager_inputs)
            compiled_inductor_func = self._compile_with_inductor(
                inductor_func, inductor_inputs
            )
            inductor_out = eager_func(
                compiled_inductor_func(*inductor_inputs), *eager_inputs
            )
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_allreduce_input_buffer_reuse(self):
        """Compiled output must match eager when the all_reduce input is also read afterwards."""

        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            # `a` is read again after being handed to the collective.
            c = torch.relu(a)
            d = torch.matmul(c, c)
            e = d + ar
            return (e,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = torch.ones(4, 4, device="cuda") + self.rank
            compiled = torch.compile(func)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_permute_tensor(self):
        """permute_tensor must give the same (expected) result in eager and compiled mode."""

        def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
            return _functional_collectives.permute_tensor(
                tensor, src_dst_pairs, ranks, tag
            )

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = (
                # rank0: [0., 1.], rank1: [2., 3.]
                torch.arange(2, dtype=torch.float32, device="cuda") + 2 * self.rank,
                [1, 0],
            )
            compiled = torch.compile(func)
            out = compiled(*inputs, **self.get_world_trs())
            correct = func(*inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

            # rank0: [2., 3.], rank1: [0., 1.]
            expected = torch.arange(2, dtype=torch.float32, device="cuda") + 2 * (
                (self.rank - 1 + self.world_size) % self.world_size
            )
            self.assertEqual(out, expected)
            self.assertEqual(correct, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    def test_allgather_output_buffer_reuse(self):
        """Compiled all_gather_tensor followed by chunk/cat must match eager."""

        class Model(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.emb = torch.nn.Embedding(4, 4)

            def forward(self, x, world_size, tag, ranks, group_size):
                y = self.emb(x)
                last_dim = y.dim() - 1
                res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
                out = torch.cat(torch.chunk(res, world_size, dim=0), dim=last_dim)
                return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().cuda()
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
            correct = model(inp, self.world_size, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allgather_contiguous_input(self):
        """Compiled and eager must agree when the all_gather input was transposed in place."""

        class Model(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.emb = torch.nn.Embedding(4, 4)

            def forward(self, x, world_size, tag, ranks, group_size):
                y = self.emb(x)
                last_dim = y.dim() - 1
                y = y.transpose_(0, last_dim).contiguous()
                # NOTE(review): `res` is never read below; the returned value
                # is derived from `y` only -- confirm this is intended.
                res = _functional_collectives.all_gather_tensor(y, 0, ranks, tag)
                out = y.transpose_(0, last_dim).contiguous()
                return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().cuda()
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device="cuda")
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
            correct = model(inp, self.world_size, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_allgather_into_tensor_inductor(self):
        """
        matmul followed by all_gather_into_tensor + wait_tensor, eager vs inductor.
        """

        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
            ag = torch.ops.c10d_functional.all_gather_into_tensor(
                c, tag, ranks, group_size
            )
            ag = torch.ops.c10d_functional.wait_tensor(ag)
            return (ag,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = example(*inputs)
            compiled_example = self._compile_with_inductor(example, inputs)
            inductor_out = compiled_example(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_inductor(self):
        """matmul followed by reduce_scatter_tensor + wait_tensor, eager vs inductor."""

        def example(a, b, *, tag, ranks, group_size):
            c = torch.matmul(a, b)
            ag = torch.ops.c10d_functional.reduce_scatter_tensor(
                c, "sum", tag, ranks, group_size
            )
            ag = torch.ops.c10d_functional.wait_tensor(ag)
            return (ag,)

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            example = functools.partial(
                example,
                **self.get_world_trs(),
            )
            inputs = (torch.ones(4, 4, device="cuda") + self.rank,) * 2

            eager_out = example(*inputs)
            compiled_fn = self._compile_with_inductor(example, inputs)
            inductor_out = compiled_fn(*inputs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_all_to_all_single_inductor(self):
        """all_to_all_single with split sizes read from tensors, eager vs inductor.

        Also checks that the generated code passes unbacked symbols (``u<N>``)
        as the two split-size lists.
        """

        def example(
            inp,
            input_split_sizes_tensor,
            output_split_sizes_tensor,
            *,
            tag,
            ranks,
            group_size,
        ):
            input_split_sizes = _tolist_with_constrain_as_size(input_split_sizes_tensor)
            output_split_sizes = _tolist_with_constrain_as_size(
                output_split_sizes_tensor
            )
            a2a = torch.ops.c10d_functional.all_to_all_single(
                inp,
                output_split_sizes,
                input_split_sizes,
                tag,
                ranks,
                group_size,
            )
            a2a = torch.ops.c10d_functional.wait_tensor(a2a)
            out = a2a / a2a.sum(dim=0)
            return out

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size
        ), torch._dynamo.config.patch(
            dynamic_shapes=True,
            capture_dynamic_output_shape_ops=True,
            capture_scalar_outputs=True,
        ):
            # Equals the sum of the split sizes built below for this rank.
            row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
            input_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            output_split_sizes_tensor = torch.tensor(
                [(i + 1) * (self.rank + 1) for i in range(self.world_size)],
                dtype=torch.int64,
            )
            inputs = (
                torch.ones(int(row), 5, device="cuda") * (self.rank + 1),
                input_split_sizes_tensor,
                output_split_sizes_tensor,
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
            code = run_and_get_triton_code(compiled_fn, *inputs, **trs)
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, "
                    "\\[u\\d+, u\\d+\\], "
                    "\\[u\\d+, u\\d+\\]"
                )
                .run(code)
            )

            eager_out = example(*inputs, **trs)
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_all_to_all_single_inductor_split_sizes_none(self):
        """all_to_all_single with both split-size arguments ``None``, eager vs inductor.

        Also checks that the generated code passes ``(s<N> // <d>)`` expressions
        as the two split-size lists.
        """

        def example(inp, *, tag, ranks, group_size):
            a2a = torch.ops.c10d_functional.all_to_all_single(
                inp,
                None,
                None,
                tag,
                ranks,
                group_size,
            )
            a2a = torch.ops.c10d_functional.wait_tensor(a2a)
            out = a2a / a2a.sum(dim=0)
            return out

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            inputs = (
                torch.ones(self.world_size, self.world_size, device="cuda")
                * (self.rank + 1),
            )
            trs = self.get_world_trs()

            compiled_fn = torch.compile(example, fullgraph=True, dynamic=True)
            code = run_and_get_triton_code(compiled_fn, *inputs, **trs)
            (
                FileCheck()
                .check_regex(
                    "torch.ops._c10d_functional.all_to_all_single.default\\("
                    "arg\\d+_\\d+, "
                    "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], "
                    "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]"
                )
                .run(code)
            )

            eager_out = example(*inputs, **trs)
            inductor_out = compiled_fn(*inputs, **trs)
            self.assertTrue(same(eager_out, inductor_out, tol=0.001))
@instantiate_parametrized_tests
@requires_nccl()
@requires_cuda
class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
    """
    Prefer single-proc test runner for basic tests as it is easier to work with.
    """

    def get_world_trs(self, world_size=1):
        """Return ``tag``/``ranks``/``group_size`` kwargs for a group of ``world_size`` ranks."""
        return {
            "tag": "",
            "ranks": list(range(world_size)),
            "group_size": world_size,
        }

    @staticmethod
    def _all_reduce_nodes(gm):
        """Return the nodes of ``gm.graph`` whose target is a functional all_reduce op."""
        all_reduce_targets = [
            torch.ops.c10d_functional.all_reduce,
            torch.ops._c10d_functional.all_reduce,
        ]
        return [node for node in gm.graph.nodes if node.target in all_reduce_targets]

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(debug=True)
    def test_inductor_single_op(self):
        """A graph with only all_reduce + wait_tensor: check generated code and result."""

        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.c10d_functional.all_reduce(
                inp, "sum", tag, ranks, group_size
            )
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            return ar

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        out = compiled(inputs, **self.get_world_trs())
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check(".run(arg0_1, buf0, 16")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("return (buf0")
            .run(code)
        )
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch(debug=True)
    def test_inductor_steal_buffer(self):
        """
        it's ok and optimal if inductor allreduce mutates the buffer of an intermediate
        that isn't going to be used again
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size)
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, other

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check(".run(arg0_1, buf0")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("buf5 = empty_strided")
            .check(".run(buf5, 16")
            .check("return (buf0, buf5")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_doesnt_mutate_shared(self):
        """
        make sure that an intermediate that's going to be reuse isn't mutated unless copied
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            ar = torch.ops.c10d_functional.all_reduce(x, "sum", tag, ranks, group_size)
            y = x + 2
            ar = torch.ops.c10d_functional.wait_tensor(ar)
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar, y, other

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf5 = empty_strided")
            .check(".run(arg0_1, buf0, buf5, 16")
            .check("torch.ops._c10d_functional.all_reduce_.default(buf0")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
            .check("buf6 = empty_strided")
            .check(".run(buf6, 16")
            .check("return (buf0, buf5, buf6")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_allreduce(self):
        """Dynamo traces the functional all_reduce wrapper into a single frame."""

        def func(inp):
            ar = _functional_collectives.all_reduce(inp, "sum", "0")
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs)
        correct = func(inputs)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (all_reduce, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_all_gather_tensor(self):
        """Dynamo traces the functional all_gather_tensor wrapper into a single frame."""

        def func(inp):
            ar = _functional_collectives.all_gather_tensor(inp, 0, "0")
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs)
        correct = func(inputs)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (all_gather, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_all_gather_tensor_pg(self):
        """Same as the all_gather_tensor trace test, but the group is a process-group object."""

        def func(inp, *, pg):
            ar = _functional_collectives.all_gather_tensor(inp, 0, pg)
            return ar

        inputs = torch.ones(4, 4, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        out = compiled(inputs, pg=GroupMember.WORLD)
        correct = func(inputs, pg=GroupMember.WORLD)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (all_gather, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_rewrite_dist_all_gather(self):
        """``torch.distributed.all_gather_into_tensor`` compiles with fullgraph=True."""

        def func(inp, out, *, pg):
            torch.distributed.all_gather_into_tensor(
                out,
                inp,
                pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_)
        self.assertEqual(counter.op_count, 3)
        self.assertTrue(same(outputs, correct_outputs))

    def test_dynamo_rewrite_dist_all_gather_list(self):
        """``torch.distributed.all_gather`` with a list output compiles with fullgraph=True."""

        def func(inp, out, *, pg):
            torch.distributed.all_gather(
                out,
                inp,
                pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = [torch.empty(global_size, device=self.device)]
        correct_outputs = [torch.empty(global_size, device=self.device)]
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        self.assertEqual(counter.frame_count, 1)
        self.assertTrue(same(outputs, correct_outputs))

    def test_dynamo_rewrite_dist_all_gather_args_match(self):
        # Duplicated most of the structure from test_dynamo_rewrite_dist_all_gather
        # except uses kwargs to ensure rewrite has matching arg names
        def func(inp, out, *, pg):
            torch.distributed.all_gather_into_tensor(
                output_tensor=out,
                input_tensor=inp,
                group=pg,
                async_op=False,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 3 is supposed to be (all_gather, wait, copy_)
        self.assertEqual(counter.op_count, 3)
        self.assertTrue(same(outputs, correct_outputs))

    def test_dynamo_rewrite_dist_reduce_scatter(self):
        """``torch.distributed.reduce_scatter_tensor`` compiles with fullgraph=True."""

        def func(inp, out, *, pg):
            torch.distributed.reduce_scatter_tensor(
                out,
                inp,
                group=pg,
            )

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 3 is supposed to be (reduce_scatter, wait, copy_)
        self.assertEqual(counter.op_count, 3)
        self.assertTrue(same(outputs, correct_outputs))

    @parametrize(
        "pg_mode",
        [
            "positional",
            "positional_none",
            "kwargs",
            "kwargs_none",
            "unspecified",
        ],
    )
    def test_dynamo_rewrite_dist_allreduce(self, pg_mode):
        """``torch.distributed.all_reduce`` compiles for each way of passing the group."""

        def func(tensor, *args, **kwargs):
            torch.distributed.all_reduce(
                tensor,
                *args,
                **kwargs,
            )

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)

        args = []
        kwargs = {}

        if pg_mode == "positional":
            args.append(torch.distributed.ReduceOp.MAX)
            args.append(GroupMember.WORLD)
        elif pg_mode == "positional_none":
            args.append(torch.distributed.ReduceOp.MAX)
            args.append(None)
        elif pg_mode == "kwargs":
            kwargs["group"] = GroupMember.WORLD
        elif pg_mode == "kwargs_none":
            kwargs["group"] = None
        else:
            self.assertEqual(pg_mode, "unspecified")

        inputs_compiled = torch.ones(2, device=self.device)
        inputs_eager = torch.ones(2, device=self.device)

        compiled(inputs_compiled, *args, **kwargs)
        func(inputs_eager, *args, **kwargs)

        self.assertEqual(counter.frame_count, 1)
        # should test more precisely, but the 3 is supposed to be (all_reduce, wait, copy_)
        self.assertEqual(counter.op_count, 3)
        self.assertTrue(same(inputs_compiled, inputs_eager))

    def test_dynamo_rewrite_dist_all_to_all_single(self):
        """``torch.distributed.all_to_all_single`` compiles with fullgraph=True."""

        def func(output, input, pg):
            torch.distributed.all_to_all_single(output, input, group=pg)

        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)

        input_compiled = torch.ones(2, device=self.device)
        input_eager = torch.ones(2, device=self.device)
        output_compiled = torch.empty(2, device=self.device)
        output_eager = torch.empty(2, device=self.device)

        compiled(output_compiled, input_compiled, GroupMember.WORLD)
        func(output_eager, input_eager, GroupMember.WORLD)

        self.assertEqual(counter.frame_count, 1)
        self.assertTrue(same(output_compiled, output_eager))

    @parametrize(
        "reduce_op",
        [
            torch.distributed.ReduceOp.SUM,
            torch.distributed.ReduceOp.AVG,
            torch.distributed.ReduceOp.PRODUCT,
            torch.distributed.ReduceOp.MIN,
            torch.distributed.ReduceOp.MAX,
        ],
    )
    def test_dynamo_rewrite_dist_allreduce_reduce_op(self, reduce_op):
        """The captured all_reduce node carries the string form of ``reduce_op``."""
        from torch.distributed._functional_collectives import REDUCE_OP_TO_STR

        def verify_rewrite(gm, _):
            # Used as the compile backend: inspects the captured graph and
            # returns it unchanged.
            ar_nodes = self._all_reduce_nodes(gm)
            self.assertEqual(len(ar_nodes), 1)
            reduce_op_str = ar_nodes[0].args[1]
            self.assertEqual(REDUCE_OP_TO_STR[reduce_op], reduce_op_str)
            return gm

        compiled = torch.compile(
            torch.distributed.all_reduce,
            backend=verify_rewrite,
            fullgraph=True,
        )
        inputs = (
            torch.ones(2, device=self.device),
            reduce_op,
            GroupMember.WORLD,
        )
        compiled(*inputs)

    @parametrize(
        "source",
        [
            "GroupMember.WORLD",
            "group.WORLD",
            "_get_default_group",
        ],
    )
    def test_dynamo_get_world_group(self, source):
        """Each way of obtaining the world group yields exactly one all_reduce node."""

        def func(tensor):
            if source == "GroupMember.WORLD":
                group = torch.distributed.GroupMember.WORLD
            elif source == "group.WORLD":
                group = torch.distributed.group.WORLD
            else:
                # Kept as a plain assert: this function is what gets compiled.
                assert source == "_get_default_group"
                group = torch.distributed.distributed_c10d._get_default_group()

            torch.distributed.all_reduce(
                tensor,
                group=group,
            )

        def verify(gm, _):
            # Used as the compile backend: inspects the captured graph and
            # returns it unchanged.
            ar_nodes = self._all_reduce_nodes(gm)
            self.assertEqual(len(ar_nodes), 1)
            return gm

        compiled = torch.compile(func, backend=verify, fullgraph=True)
        input = torch.ones(2, device=self.device)
        compiled(input)

    def test_dynamo_support_collective_op_with_async_op_False(self):
        """An explicit ``async_op=False`` still compiles into a single frame."""

        def func(inp, out, *, pg):
            # user explicitly set the attribute `async_op` to False,
            # there should be no graph break
            torch.distributed.reduce_scatter_tensor(out, inp, group=pg, async_op=False)

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 3)
        self.assertTrue(same(outputs, correct_outputs))

    def test_dynamo_graphbreaks_unsupported_async_op(self):
        """With ``async_op=True`` the backend is expected to receive no frames and no ops."""

        def func(inp, out, *, pg):
            work = torch.distributed.reduce_scatter_tensor(
                out, inp, group=pg, async_op=True
            )
            work.wait()

        local_size = [4, 4]
        # single-proc test
        global_size = local_size

        inputs = torch.ones(local_size, device=self.device)
        outputs = torch.empty(global_size, device=self.device)
        correct_outputs = torch.empty(global_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        compiled(inputs, outputs, pg=GroupMember.WORLD)
        func(inputs, correct_outputs, pg=GroupMember.WORLD)
        self.assertEqual(counter.frame_count, 0)
        self.assertEqual(counter.op_count, 0)
        self.assertTrue(same(outputs, correct_outputs))

    def test_dynamo_pg_var(self):
        """``pg.rank()`` and ``pg.size()`` can be called inside a fullgraph-compiled function."""

        def func(inp, *, pg):
            # Parenthesized explicitly: without the parentheses `%` binds
            # tighter than `+`, giving rank + (1 % size).
            x = (pg.rank() + 1) % pg.size()
            return inp + x

        local_size = [4, 4]
        inputs = torch.ones(local_size, device=self.device)
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter, fullgraph=True)
        outputs = compiled(inputs, pg=GroupMember.WORLD)
        correct_outputs = func(inputs, pg=GroupMember.WORLD)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 1)
        self.assertTrue(same(outputs, correct_outputs))

    def test_dynamo_trace_reduce_scatter_tensor(self):
        """Dynamo traces the functional reduce_scatter_tensor wrapper into a single frame."""

        def func(inp):
            ar = _functional_collectives.reduce_scatter_tensor(inp, "sum", 0, "0")
            return ar

        inputs = torch.ones(4, 4, device="cuda")
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs)
        correct = func(inputs)
        self.assertEqual(counter.frame_count, 1)

        # should test more precisely, but the 2 is supposed to be (reduce_scatter, wait)
        self.assertEqual(counter.op_count, 2)
        self.assertTrue(same(out, correct))

    def test_dynamo_trace_allgather_coalesced(self):
        """Dynamo traces ``all_gather_into_tensor_coalesced`` on a list of two tensors."""

        def func(inp, *, tag, ranks, group_size):
            ar = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
                inp, tag, ranks, group_size
            )
            return ar

        inputs = [torch.ones(4, 4, device="cuda"), torch.ones(6, 6, device="cuda")]
        counter = CompileCounter()
        compiled = torch.compile(func, backend=counter)
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 3)  # It generates 2 getattr to unpack the array
        self.assertTrue(same(out, correct))

    def test_backwards(self):
        """
        It's probably not that common to need backwards support for collectives.
        However, I wanted to at least see if it was possible to support it as a design goal.
        """

        def func(inp):
            ar = _functional_collectives.all_reduce(inp, "sum", "0")
            return ar

        input = torch.ones(4, 4, device="cuda", requires_grad=True)
        # TODO implement backwards
        with self.assertRaisesRegex(
            RuntimeError,
            "element 0 of tensors does not require grad and does not have a grad_fn",
        ):
            compiled = torch.compile(
                func, backend="aot_eager"
            )  # inductor bug with single-op allreduce graph
            out = compiled(input)
            out.sum().backward()

            # NOTE: once the expected RuntimeError is raised above, the rest of
            # this block is skipped. It is kept for when backwards is implemented.
            correct_input = input.clone().detach().requires_grad_()
            correct = func(correct_input)
            correct.sum().backward()
            self.assertTrue(same(out, correct))
            self.assertTrue(same(input.grad, correct_input.grad))

    def test_meta(self):
        """all_reduce on a meta tensor returns a tensor of the same size."""
        x = torch.rand((2, 3, 4), device="meta")
        out = torch.ops.c10d_functional.all_reduce(x, "sum", **self.get_world_trs())
        self.assertEqual(x.size(), out.size())

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_all_gather_coalesced(self):
        """
        Check the code inductor generates for all_gather_into_tensor_coalesced when
        one of the gathered inputs is an intermediate that is read again afterwards.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            tensor_list = torch.ops.c10d_functional.all_gather_into_tensor_coalesced(
                [x, inp], tag, ranks, group_size
            )
            y = x + 2
            ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0])
            ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1])
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar0, y, other, ar1

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: Make sure we are not unnecessarily copying the outputs of
        # wait_tensors before they are returned from the graph.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf6 = empty_strided")
            .check(".run(arg0_1, buf0, buf6, 16")
            .check(
                "buf1 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced.default([buf0, arg0_1]"
            )
            .check("buf2 = buf1[0]")
            .check("buf3 = buf1[1]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("buf7 = buf0; del buf0  # reuse")
            .check(".run(buf7, 16")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("return (buf2, buf6, buf7, buf3")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct), f"{out} vs {correct}")

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch({"debug": True, "triton.descriptive_names": False})
    def test_inductor_reduce_scatter_coalesced(self):
        """
        Check the code inductor generates for reduce_scatter_tensor_coalesced when
        one of the inputs is an intermediate that is read again afterwards.
        """

        def func(inp, *, tag, ranks, group_size):
            x = inp + 1
            tensor_list = torch.ops.c10d_functional.reduce_scatter_tensor_coalesced(
                [x, inp], "sum", tag, ranks, group_size
            )
            y = x + 2
            ar0 = torch.ops.c10d_functional.wait_tensor(tensor_list[0])
            ar1 = torch.ops.c10d_functional.wait_tensor(tensor_list[1])
            # ensure other is not incorrectly aliasing ar's buffer
            other = torch.ones_like(inp) + 22
            return ar0, y, other, ar1

        inputs = torch.ones(4, 4, device="cuda")

        compiled = torch.compile(func)
        code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
        # NOTE: The first return value should be the output of the first wait_tensor.
        # We want to make sure no unnecessary copy is made.
        (
            FileCheck()
            .check("buf0 = empty_strided")
            .check("buf6 = empty_strided")
            .check(".run(arg0_1, buf0, buf6, 16")
            .check(
                "buf1 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced.default([buf0, arg0_1]"
            )
            .check("buf2 = buf1[0]")
            .check("buf3 = buf1[1]")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf2")
            .check("buf7 = buf0; del buf0  # reuse")
            .check(".run(buf7, 16")
            .check("torch.ops._c10d_functional.wait_tensor.default(buf3")
            .check("return (buf2, buf6, buf7, buf3")
            .run(code)
        )
        out = compiled(inputs, **self.get_world_trs())
        correct = func(inputs, **self.get_world_trs())
        self.assertTrue(same(out, correct), f"{out} vs {correct}")
if __name__ == "__main__":
    # Script entry point: hand control to the dynamo test runner.
    torch._dynamo.test_case.run_tests()