blob: d5b9f070b403e153596e08a80ba6290997427b33 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import copy
import functools
import os
import random
import unittest
from unittest.mock import patch
import numpy as np
import torch
import torch._dynamo
from torch._dynamo.optimizations.distributed import DDPOptimizer
import torch._dynamo.test_case
import torch.distributed as dist
from contextlib import contextmanager
from torch import nn
from torch._dynamo import config
from torch._dynamo.utils import same
from torch._dynamo.testing import collect_results
from torch._inductor.utils import has_triton
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
import_transformers_or_skip,
skip_if_lt_x_gpu,
requires_nccl
)
import torch._dynamo.logging
def reset_rng_state():
    """Re-seed the torch, ``random`` and NumPy generators with one fixed seed.

    Calling this before two runs makes them draw the same random numbers.
    """
    # Order is kept as torch -> random -> numpy.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(1337)
def init_weights(m):
    """Initialise an ``nn.Linear`` in place; leave any other module untouched.

    Intended to be passed to ``Module.apply``. For linear layers the weight
    gets Xavier-uniform values and the bias (when the layer has one) is filled
    with the constant 0.01.

    Args:
        m: the module visited by ``apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    # nn.Linear(..., bias=False) stores bias as None; nothing to fill then.
    if m.bias is not None:
        m.bias.data.fill_(0.01)
class ToyModel(nn.Module):
    """A stack of four ``Linear`` + ``ReLU`` stages held in ``self.net``.

    Layer widths run ``in_feat -> hidden_feat -> hidden_feat -> hidden_feat
    -> out_feat``.
    """

    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
        super().__init__()
        widths = (in_feat, hidden_feat, hidden_feat, hidden_feat, out_feat)
        stages = []
        # Walk consecutive width pairs; each pair becomes one Linear + ReLU.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.ReLU())
        self.net = nn.Sequential(*stages)

    def forward(self, inputs):
        """Run ``inputs`` through the sequential stack."""
        return self.net(inputs)
def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
    """Build an initialised ``ToyModel`` on ``device`` plus a random batch.

    Returns:
        A ``(model, inputs, outputs)`` tuple where ``inputs`` is a
        ``(bsz, in_feat)`` random tensor moved to ``device`` and ``outputs``
        is the model's output for it.
    """
    model = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat)
    model = model.to(device)
    model.apply(init_weights)
    # The batch is sampled first and then moved to the target device.
    batch = torch.rand(bsz, in_feat).to(device)
    return model, batch, model(batch)
def get_custom_model(device):
    """Build a small model that mixes stock ``Linear`` layers with a custom one.

    Returns:
        A ``(model, inputs, correct_outputs)`` tuple; ``inputs`` is a random
        ``(512, 512)`` tensor on ``device`` and ``correct_outputs`` is the
        model's output for it.
    """

    class MyCustomLinear(torch.nn.Module):
        # Bias-free linear layer that owns a raw Parameter instead of nn.Linear.
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(512, 512))

        def forward(self, x):
            return torch.mm(x, self.weight.t())

    class MyLinear(torch.nn.Module):
        # Thin wrapper around a stock 512x512 nn.Linear.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            return self.linear(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Sandwich the custom layer in the middle so that a stock layer
            # comes both before and after it; every layer is followed by ReLU.
            layers = []
            for block in (MyLinear(), MyCustomLinear(), MyLinear()):
                layers += [block, torch.nn.ReLU()]
            self.seq = torch.nn.Sequential(*layers)

        def forward(self, x):
            return self.seq(x)

    model = MyModule().to(device)
    model.apply(init_weights)
    batch = torch.rand((512, 512)).to(device)
    return model, batch, model(batch)
def get_hf_bert(rank):
# Note: use @import_transformers_or_skip on your test case if you use this
# in a multiprocessing test
try:
from transformers import BertConfig, AutoModelForMaskedLM
except ImportError:
raise unittest.SkipTest("Unable to import transformers")
batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}"
model = AutoModelForMaskedLM.from_config(config).to(device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device)
decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device)
inputs = {'input_ids': input_ids, 'labels': decoder_ids}
model.train()
return model, inputs
class CheckSplitsCompiler:
    """Pass-through compiler backend that counts how often it is invoked."""

    def __init__(self):
        # Number of compile_fn calls seen so far.
        self.compiler_called = 0

    def compile_fn(self, gm, example_inputs):
        """Record one invocation and return ``gm`` unchanged."""
        self.compiler_called = self.compiler_called + 1
        return gm
@contextmanager
def _per_rank_init(rank, world_size):
    """Set up and tear down the per-rank distributed and Dynamo state.

    Selects CUDA device ``rank``, initialises an NCCL process group, and
    resets Dynamo and its counters both before and after the managed block.
    Teardown runs even when the managed block raises.

    Args:
        rank: rank of the current process; also used as the CUDA device index.
        world_size: total number of ranks in the process group.
    """
    # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
    # Just manually implement the most important part of the dynamo behavior to reset/clear.
    torch.cuda.set_device(rank)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6789'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        torch._dynamo.reset()
        torch._dynamo.utils.counters.clear()
        yield
    finally:
        # Always clean up, otherwise a failing test body would leave the
        # process group alive and stale Dynamo state behind.
        torch._dynamo.reset()
        torch._dynamo.utils.counters.clear()
        dist.destroy_process_group()
# This simulates DDP, but it doesn't actually do any process communication;
# it just has enough properties so that the dynamo distributed optimization is
# able to optimize. Feel free to simulate more properties as necessary. The
# other important thing is patching _active_ddp_module, which is what actually
# triggers DDP optimization
class FakeDDP(nn.Module):
    """Wrapper that mimics DDP's attributes without doing any communication.

    It stores the wrapped ``module``, exposes a ``bucket_bytes_cap`` of 25 MiB,
    and sets ``DDP._active_ddp_module`` to itself while ``forward`` runs.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module
        bucket_cap_mb = 25
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

    @contextmanager
    def _inside_ddp_forward(self):
        """Mark this instance as the active DDP module for the managed block."""
        DDP._active_ddp_module = self
        try:
            yield
        finally:
            # Cleared on success and on error; exceptions propagate as-is.
            DDP._active_ddp_module = None

    def forward(self, *inputs, **kwargs):
        """Call the wrapped module's ``forward`` with the active-DDP marker set."""
        with self._inside_ddp_forward():
            return self.module.forward(*inputs, **kwargs)
def run_hf_bert_ddp(self, model, inputs, backend):
    """Compare an eager and a Dynamo-compiled forward/backward pass of ``model``.

    Args:
        self: the test case whose ``assertTrue`` reports the result.
        model: module called as ``model(**inputs)``; its output must expose
            ``loss`` and ``logits``.
        inputs: keyword arguments for the model.
        backend: backend passed to ``torch._dynamo.optimize``.
    """
    # Eager reference run.
    reset_rng_state()
    eager_out = model(**inputs)
    eager_loss = eager_out.loss
    eager_loss.backward()

    # Compiled run, starting from the same RNG state.
    reset_rng_state()
    compiled_model = torch._dynamo.optimize(backend)(model)
    compiled_out = compiled_model(**inputs)
    compiled_loss = compiled_out.loss
    compiled_loss.backward()

    flat_inputs = list(inputs.values())
    expected = collect_results(model, eager_out.logits, eager_loss, flat_inputs)
    actual = collect_results(compiled_model, compiled_out.logits, compiled_loss, flat_inputs)
    self.assertTrue(same(expected, actual))
class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
    """Single-process BERT tests that use ``FakeDDP`` in place of real DDP."""

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        """Eager and inductor-compiled BERT under FakeDDP must agree."""
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "inductor")

    # get_hf_bert(0) places the model on cuda:0, so skip rather than error
    # out on hosts without CUDA.
    @unittest.skipIf(not torch.cuda.is_available(), "requires a CUDA device")
    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        """Eager and aot_eager-compiled BERT under FakeDDP must agree."""
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "aot_eager")
# Are these tests failing? Check and see if TestFakeDistributedSingleProc has a
# single process version; if it's just a problem in the Dynamo distributed
# optimizer, you should be able to repro it single process!
@requires_nccl()
class TestDistributedMultiProc(MultiProcessTestCase):
    """Dynamo + DDP/FSDP tests in which every rank runs in its own process.

    ``world_size`` equals the number of visible CUDA devices and each test body
    runs inside ``_per_rank_init`` for its rank.
    """

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        # Best-effort removal of the file named by self.file_name; it is fine
        # if it is already gone.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
        # Don't enable DDP + ReplicatedTensor, as that breaks Dynamo+DDP
        # TODO(whc) why is ReplicatedTensor defaulted=True in MultiProcessTestCase, and should we support it?
        # from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
        # _set_ddp_with_replicated_tensor(True)

        # The rest is copypasta from MultiProcessTestCase._run
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)

    def _check_fsdp_toy_model(self, backend):
        """Compile FSDP-wrapped toy models with ``backend`` and compare to eager.

        Covers two wrappings: a single outer FSDP wrap (``auto_wrap_policy``
        of None) and recursive wrapping with a nested FSDP around each Linear.
        """
        per_linear_policy = functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, )
        )
        with _per_rank_init(self.rank, self.world_size):
            for wrap_policy in (None, per_linear_policy):
                m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
                fsdp_m = FSDP(
                    m,
                    auto_wrap_policy=wrap_policy,
                    use_orig_params=True
                )
                fsdp_m = torch._dynamo.optimize(backend)(fsdp_m)
                outputs = fsdp_m(inputs)
                self.assertTrue(same(correct_outputs, outputs))

    @skip_if_lt_x_gpu(2)
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager_multiprocess(self):
        """DDP + aot_eager with the DDP graph-split optimization disabled."""
        with _per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            m = DDP(m, device_ids=[self.rank])
            m = torch._dynamo.optimize("aot_eager")(m)
            outputs = m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        """BERT under real DDP: eager vs. inductor results must agree."""
        with _per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model)
            run_hf_bert_ddp(self, model, inputs, "inductor")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        """BERT under real DDP: eager vs. aot_eager results must agree."""
        with _per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model)
            run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @skip_if_lt_x_gpu(1)
    def test_fsdp_aot_eager(self):
        """FSDP-wrapped toy model compiled with aot_eager matches eager."""
        self._check_fsdp_toy_model("aot_eager")

    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_inductor(self):
        """FSDP-wrapped toy model compiled with inductor matches eager."""
        self._check_fsdp_toy_model("inductor")

    # Guarded like the other FSDP tests: the body places the model on a GPU.
    @skip_if_lt_x_gpu(1)
    @import_transformers_or_skip()
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_fsdp(self):
        """BERT under FSDP (flat and per-BertLayer wrapping): eager vs. inductor."""
        from transformers.models.bert.modeling_bert import BertLayer

        def apply_fsdp(model, wrap_policy):
            # Wrap a deep copy so the eager and compiled runs start from the
            # same weights.
            model = FSDP(
                copy.deepcopy(model),
                auto_wrap_policy=wrap_policy,
                use_orig_params=True
            )
            return model

        test_cases = (
            (
                None,
                "FSDP without recursive wrapping"
            ),
            (
                functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer, )
                ),
                "FSDP with recursive wrapping BertLayer instances"
            ),
        )

        with _per_rank_init(self.rank, self.world_size):
            for (wrap_policy, test_instance) in test_cases:
                print(f"Running hf_bert test for {test_instance}")
                model, inputs = get_hf_bert(self.rank)

                reset_rng_state()
                eager_model = apply_fsdp(model, wrap_policy)
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp(model, wrap_policy)
                opt_model = torch._dynamo.optimize("inductor")(opt_model)
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = [inputs[k] for k in inputs]
                correct_results = collect_results(eager_model, correct_outputs.logits, correct_loss, inputs_flat)
                opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat)
                self.assertTrue(same(correct_results, opt_results))
@requires_nccl()
class TestDistributed(torch._dynamo.test_case.TestCase):
    """
    Test harness initializes dist process group
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # _exit_stack is set up in TestCase
        cls._exit_stack.enter_context(
            patch.dict(
                os.environ,
                {
                    "MASTER_ADDR": "localhost",
                    "MASTER_PORT": "12355",
                },
            )
        )
        cls.rank = 0
        cls.device = f"cuda:{cls.rank}"
        # NOTE(review): cls.device is always a CUDA device here, so device_ids
        # is always None and the non-CUDA branch is never taken - confirm the
        # condition is the intended way round.
        cls.device_ids = None if "cuda" in cls.device else [cls.rank]
        dist.init_process_group("nccl", rank=cls.rank, world_size=1)

    @classmethod
    def tearDownClass(cls):
        dist.destroy_process_group()
        super().tearDownClass()

    def get_model(self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
        """Return ``(model, inputs, outputs)`` for a ToyModel on ``self.device``.

        Delegates to the module-level ``get_model`` helper so both stay in sync.
        """
        return get_model(
            self.device,
            bsz=bsz,
            in_feat=in_feat,
            hidden_feat=hidden_feat,
            out_feat=out_feat,
        )

    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager(self):
        """DDP + aot_eager with the DDP graph-split optimization disabled."""
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_inductor(self):
        """DDP + inductor with the DDP graph-split optimization disabled."""
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("inductor")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split(self):
        """
        Just ensures that the appropriate number of splits happen (based on
        bucket size and model parameters) - verifies the number of times
        the user-provided compiler is called by the DDPOptimizer which is
        doing the graph splitting
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

    @patch.object(config, "optimize_ddp", True)
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor(self):
        """
        Same as test_graph_split, but using inductor backend.
        We observed issues with inductor/fx interface in the past.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("inductor")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_no_split(self):
        """
        Ensures the DDPOptimizer returns a correct, compiled module without
        introducing graph splits. (Based on model parameters fitting in the bucket)
        """
        # DDP will always do a 'first bucket' with a really small size; so only a tiny model will escape this
        m, inputs, correct_outputs = self.get_model(hidden_feat=5)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 1)

    @patch.object(config, "optimize_ddp", True)
    def test_aot_autograd(self):
        """
        Explicitly check AotAutograd family of compilers work,
        since they require example inputs propagated between graph splits.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("aot_eager")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        opt_outputs.sum().backward()
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_custom_layer(self):
        """
        Checks the number of compiler invocations for the model from
        get_custom_model, which mixes stock Linear layers with a custom layer
        that owns a raw Parameter, using a 1MB bucket cap.
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_empty_graph_inductor(self):
        """A function that only queries the world size compiles and returns 1."""
        def fn():
            get_world_size = torch.distributed.distributed_c10d.get_world_size()
            return (get_world_size,)

        opt_fn = torch._dynamo.optimize("inductor")(fn)
        # Let any exception propagate: swallowing it would hide the real
        # failure behind an uninformative "None != 1".
        res = opt_fn()[0]
        self.assertEqual(res, 1)

    @patch.object(config, "optimize_ddp", False)
    def test_ignored_parameters(self):
        """
        Verifies ddp graph-split logic ignores parameters marked to ignore on DDP module.
        Hooks up graph-split optimizer manually so it can peek at internal state.
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"]
        DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
        parameter_ids_to_ignore = [
            id(ddp_m.module.get_parameter(p))
            for p in ddp_m.parameters_to_ignore
        ]

        check_splits_compiler = CheckSplitsCompiler()
        ddp_optimizer = DDPOptimizer(
            bucket_bytes_cap=ddp_m.bucket_bytes_cap,
            backend_compile_fn=check_splits_compiler.compile_fn
        )

        @torch._dynamo.optimize(ddp_optimizer.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 2)
        # No ignored parameter may end up in any bucket.
        for b in ddp_optimizer.buckets:
            for p_id in b.param_ids:
                self.assertNotIn(p_id, parameter_ids_to_ignore)
if __name__ == "__main__":
    # Script entry point: hand control to the Dynamo test runner.
    torch._dynamo.test_case.run_tests()