| # Owner(s): ["oncall: distributed"] |
| |
| from copy import deepcopy |
| from functools import wraps |
| from typing import Any, List |
| |
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| import torch.fx as fx |
| import torch.nn as nn |
| from torch.distributed._spmd.api import ( |
| compile, |
| COMPILED_OBJECT_KEY, |
| Override, |
| Schema, |
| SPMD, |
| ) |
| from torch.distributed._spmd.comm_tensor import CommTensor |
| from torch.distributed._tensor import DeviceMesh, Replicate |
| from torch.distributed._tensor.op_schema import OpSchema, OutputSharding |
| from torch.distributed._tensor.ops.utils import register_prop_rule |
| from torch.distributed._tensor.placement_types import DTensorSpec |
| from torch.distributed.distributed_c10d import get_global_rank, get_world_size |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.testing._internal.common_distributed import skip_if_lt_x_gpu |
| from torch.testing._internal.common_utils import run_tests |
| from torch.testing._internal.distributed._tensor.common_dtensor import ( |
| DTensorTestBase, |
| with_comms as base_with_comms, |
| ) |
| |
| |
| def with_comms(func): |
| @base_with_comms |
| @wraps(func) |
| def wrapper(self, *args, **kwargs): |
        # set a different random seed for each rank; otherwise DDP / SPMD
        # would be pointless here (every rank would have the same parameters
        # and inputs)
| torch.manual_seed(self.rank) |
| return func(self, *args, **kwargs) |
| |
| return wrapper |
| |
| |
| class TraceDeviceMeshTestBase: |
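    """Mixin with helpers that trace DeviceMesh collectives (all_reduce,
    broadcast, scatter, all_gather) with make_fx and re-execute the traced
    graphs to verify the communication results.
    """
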
| def _test_tracing_all_reduce_nd(self, mesh_tensor): |
| mesh = DeviceMesh(self.device_type, mesh_tensor) |
| local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank |
| |
| # check all dim groups |
| dim_to_subgroups = mesh.get_dim_groups() |
| for dim, dim_group in enumerate(dim_to_subgroups): |
| dim_group_size = get_world_size(dim_group) |
| global_ranks = [ |
| get_global_rank(dim_group, i) for i in range(dim_group_size) |
| ] |
| |
| def fn(tensor: torch.Tensor): |
| tensor = mesh.all_reduce(tensor, mesh_dim=dim) |
                # multiply by 1 to trigger a wait on read during tracing.
| return tensor * 1 |
| |
            # use local_tensor + 1 for tracing to make sure that we are not
            # simply replaying the recorded tensor value
| traced_fn = make_fx(fn)(local_tensor + 1) |
| |
| # execute traced DeviceMesh communication |
| reduced_tensor = traced_fn(local_tensor.clone()) |
| res_num = sum(global_ranks) |
| self.assertEqual(reduced_tensor, torch.ones(3, 3) * res_num) |
| |
| def _test_broadcast_nd(self, mesh_tensor): |
| mesh = DeviceMesh(self.device_type, mesh_tensor) |
| |
| # check all dim groups |
| dim_to_subgroups = mesh.get_dim_groups() |
| for dim, dim_group in enumerate(dim_to_subgroups): |
| dim_group_size = get_world_size(dim_group) |
| global_ranks = [ |
| get_global_rank(dim_group, i) for i in range(dim_group_size) |
| ] |
| |
| def fn(tensor: torch.Tensor): |
| received_tensor = CommTensor(tensor.clone()) |
| mesh.broadcast(received_tensor, mesh_dim=dim) |
                # multiply by 1 to trigger a wait on read during tracing.
| return received_tensor * 1 |
| |
| local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank |
            # use local_tensor + 1 for tracing to make sure that we are not
            # simply replaying the recorded tensor value
| traced_fn = make_fx(fn)(local_tensor + 1) |
| |
| # execute traced DeviceMesh communication |
| received_tensor = traced_fn(local_tensor) |
| res_num = global_ranks[0] |
| self.assertEqual(received_tensor, torch.ones(3, 3) * res_num) |
| |
| def _test_scatter_nd(self, mesh_tensor): |
| mesh = DeviceMesh(self.device_type, mesh_tensor) |
| |
| # check all dim groups |
| dim_to_subgroups = mesh.get_dim_groups() |
| for dim, dim_group in enumerate(dim_to_subgroups): |
| dim_group_size = get_world_size(dim_group) |
| global_ranks = [ |
| get_global_rank(dim_group, i) for i in range(dim_group_size) |
| ] |
| scattered_tensors = [ |
| torch.ones(3, 3, device=self.device_type) * global_rank |
| for global_rank in global_ranks |
| ] |
| |
| def fn(to_receive: torch.Tensor, to_scatter: List[torch.Tensor]): |
| to_scatter = [CommTensor(t) for t in to_scatter] |
| to_receive = CommTensor(to_receive) |
| mesh.scatter(to_receive, to_scatter, mesh_dim=dim) |
                # multiply by 1 to trigger a wait on read during tracing.
| return to_receive * 1 |
| |
            # trace with scattered_tensors + 1 to make sure that we are not
            # simply replaying recorded tensor values
| to_receive = torch.empty_like(scattered_tensors[mesh.get_coordinate()[dim]]) |
| traced_fn = make_fx(fn)(to_receive, [t + 1 for t in scattered_tensors]) |
| |
| received_tensor = traced_fn(to_receive, scattered_tensors) |
| self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank) |
| |
| def _test_all_gather_nd(self, mesh_tensor): |
| mesh = DeviceMesh(self.device_type, mesh_tensor) |
        # each rank has its own tensor; all_gather produces one big tensor
| local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank |
| |
| dim_to_subgroups = mesh.get_dim_groups() |
| for dim, dim_group in enumerate(dim_to_subgroups): |
| dim_group_size = get_world_size(dim_group) |
| global_ranks = [ |
| get_global_rank(dim_group, i) for i in range(dim_group_size) |
| ] |
| |
| def fn(tensor: torch.Tensor): |
| big_tensor = mesh.all_gather(tensor, mesh_dim=dim) |
| return list(torch.chunk(big_tensor, dim_group_size)) |
| |
            # use local_tensor + 1 for tracing to make sure that we are not
            # simply replaying the recorded tensor value
| traced_fn = make_fx(fn)(local_tensor + 1) |
| gathered_list = traced_fn(local_tensor) |
| |
| self.assertEqual(len(gathered_list), dim_group_size) |
| for idx, gathered_tensor in enumerate(gathered_list): |
| self.assertEqual(gathered_tensor, torch.ones(3, 3) * global_ranks[idx]) |
| |
| |
| class TraceDeviceMesh3DTest(DTensorTestBase, TraceDeviceMeshTestBase): |
| @property |
| def world_size(self): |
| return 8 |
| |
| @with_comms |
| def test_tracing_all_reduce_nd(self): |
| self._test_tracing_all_reduce_nd(torch.arange(8).reshape(2, 2, 2)) |
| |
| @with_comms |
| def test_broadcast_nd(self): |
| self._test_broadcast_nd(torch.arange(8).reshape(2, 2, 2)) |
| |
| @with_comms |
| def test_scatter_nd(self): |
| self._test_scatter_nd(torch.arange(8).reshape(2, 2, 2)) |
| |
| @with_comms |
| def test_all_gather_nd(self): |
| self._test_all_gather_nd(torch.arange(8).reshape(2, 2, 2)) |
| |
| |
| class TraceDeviceMesh2DTest(DTensorTestBase, TraceDeviceMeshTestBase): |
| @property |
| def world_size(self): |
| return 4 |
| |
| @with_comms |
| def test_tracing_all_reduce_nd(self): |
| self._test_tracing_all_reduce_nd(torch.arange(4).reshape(2, 2)) |
| |
| @with_comms |
| def test_broadcast_nd(self): |
| self._test_broadcast_nd(torch.arange(4).reshape(2, 2)) |
| |
| @with_comms |
| def test_scatter_nd(self): |
| self._test_scatter_nd(torch.arange(4).reshape(2, 2)) |
| |
| @with_comms |
| def test_all_gather_nd(self): |
| self._test_all_gather_nd(torch.arange(4).reshape(2, 2)) |
| |
| |
| class TraceModuleTest(DTensorTestBase): |
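    """Compare SPMD-compiled modules (replicated over a 1-D mesh) against DDP
    references on a variety of module structures. Gradients may differ by a
    factor of world_size because DDP averages gradients while _Partial
    tensors do not.
    """
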
| @property |
| def world_size(self): |
| return 2 |
| |
| def _test_trace_replicate(self, model: nn.Module, x, *args, **kwargs): |
| ddp = DDP(deepcopy(model)) |
| spmd = SPMD( |
| deepcopy(model), |
| schema=Schema( |
| mesh=DeviceMesh(self.device_type, torch.arange(self.world_size)), |
| placements=[Replicate()], |
| ), |
            input_schemas=kwargs.pop("inp_schemas", None),
        )
        only_fw = kwargs.pop("only_fw", False)
| if only_fw: |
| output_ddp = ddp(x, *args, **kwargs) |
| output_spmd = spmd(x, *args, **kwargs) |
            self.assertEqual(output_ddp.size(), output_spmd.size())
| return |
| ddp(x, *args, **kwargs).sum().backward() |
| spmd(x, *args, **kwargs).sum().backward() |
| for p1, p2 in zip(ddp.parameters(), spmd.parameters()): |
            # DDP divides gradients by world size to compute the average, but
            # _Partial tensors shouldn't do that automatically, so we divide
            # explicitly here.
| self.assertTrue( |
| p1.grad.allclose(p2.grad / self.world_size) or p1.grad.allclose(p2.grad) |
| ) |
| |
| @with_comms |
| def test_torch_cat(self): |
| x = torch.rand((2, 4)).to(self.device_type) |
| |
| class Model(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.nn.Parameter(torch.rand((2, 4))) |
| |
| def forward(self, x): |
                # TODO(anj): Using self.w and ignoring x results in an
                # allgather call that we do not yet support.
| return torch.cat((self.w, self.w), 0) |
| |
| model = Model().to(self.device_type) |
| inp_kwargs = {} |
| inp_kwargs["inp_schemas"] = [ |
| Schema( |
| mesh=DeviceMesh(self.device_type, torch.arange(self.world_size)), |
| placements=[Replicate()], |
| ) |
| ] |
        self._test_trace_replicate(model, x, **inp_kwargs)
| |
| @with_comms |
| def test_layer_norm_fw(self): |
        # This test covers get_item support: layer_norm's output contains
        # tuples, which means we need to support get_item.
| |
| input = np.random.randn(4, 5).astype(np.float32) |
| model = nn.LayerNorm(input.shape[1:]).to(self.device_type) |
| pt_input = torch.tensor(input, dtype=torch.float).to(self.device_type) |
| self._test_trace_replicate(model, pt_input) |
| |
| @with_comms |
| def test_baked_in_shape(self): |
| class LCE(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| torch.manual_seed(5) |
| self.w = torch.nn.Parameter(torch.rand((5, 10))) |
| self.b = torch.nn.Parameter(torch.rand((5))) |
| |
| def forward(self, x, *args, **kwargs): |
                # the code below will bake in the shape of x_t as constant
                # arguments to "expand" in the traced graph
                x_t = x.permute(0, 2, 1)
                y_t = kwargs["dict_test"]["value"].expand(x_t.shape) + args[0][
                    0
                ].expand(x_t.shape)
| return torch.nn.functional.linear(y_t, self.w, self.b) |
| |
| model = LCE().to(self.device_type) |
| x = torch.randn(2, 10, 80).to(self.device_type) |
| y = torch.randn(2, 80, 10).to(self.device_type) |
| z = torch.randn(2, 80, 10).to(self.device_type) |
| self._test_trace_replicate(model, x, [y], dict_test={"value": z}) |
| |
| @with_comms |
| def test_sequential(self): |
| model = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)]).to( |
| self.device_type |
| ) |
| x = torch.randn(2, 10).to(self.device_type) |
| self._test_trace_replicate(model, x) |
| |
| @with_comms |
| def test_parallel(self): |
| class Model(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.module_list = nn.ModuleList([nn.Linear(10, 10) for _ in range(2)]) |
| |
| def forward(self, x): |
| return sum([m(x) for m in self.module_list]) |
| |
| model = Model().to(self.device_type) |
| x = torch.randn(2, 10).to(self.device_type) |
| self._test_trace_replicate(model, x) |
| |
| @with_comms |
| def test_hybrid(self): |
| bottom_model = nn.Sequential( |
| nn.Linear(4, 8), |
| nn.Softmax(), |
| ).to(self.device_type) |
| |
| top_model = nn.Sequential( |
| nn.Linear(8, 2), |
| nn.Softmax(), |
| ).to(self.device_type) |
| |
| hybrid = nn.Sequential( |
| DDP(deepcopy(bottom_model)), |
| SPMD( |
| deepcopy(top_model), |
| schema=Schema( |
| mesh=DeviceMesh(self.device_type, torch.arange(self.world_size)), |
| placements=[Replicate()], |
| ), |
| ), |
| ) |
| ddp = DDP(nn.Sequential(deepcopy(bottom_model), deepcopy(top_model))) |
| input = torch.randn(12, 4).to(self.device_type) |
| |
| ddp(input).sum().backward() |
| hybrid(input).sum().backward() |
| for p1, p2 in zip(ddp.parameters(), hybrid.parameters()): |
            # DDP divides gradients by world size to compute the average, but
            # _Partial tensors shouldn't do that automatically, so we divide
            # explicitly here.
| self.assertTrue( |
| p1.grad.allclose(p2.grad / self.world_size) or p1.grad.allclose(p2.grad) |
| ) |
| |
| |
| class DataDependentModule(nn.Module): |
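    """A module whose eager forward would perform data-dependent collectives
    (splitting positive/negative elements and exchanging them via all_to_all).
    It is never executed eagerly in these tests; it only serves as a target to
    be replaced via an Override during compilation (see DDMOverride below).
    """
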
| def __init__(self, world_size): |
| super().__init__() |
| self.world_size = world_size |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
        raise RuntimeError(
            "This eager implementation shouldn't be executed. "
            "It is just an example of how to work around "
            "data-dependent user-defined modules."
        )
| shape = x.shape |
| x = x.view(-1) |
| positive = x[x >= 0] |
| negative = x[x < 0] |
| |
| in_sizes = torch.tensor([positive.numel(), negative.numel()], dtype=torch.int32) |
| out_sizes = torch.empty_like(in_sizes) |
| dist.all_to_all_single( |
| out_sizes, |
| in_sizes, |
| output_split_sizes=[1, 1], |
| input_split_sizes=[1, 1], |
| ) |
| |
| xs = [positive, negative] |
        ys = [torch.empty(out_sizes[i].item()) for i in range(out_sizes.numel())]
| dist.all_to_all(ys, xs) |
| |
| # some dummy compute |
| for y in ys: |
| y.add_(1) |
| |
| dist.all_to_all(xs, ys) |
| |
| return torch.cat(xs).reshape(shape) |
| |
| |
| class DummyModel(nn.Module): |
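    """A toy model that sandwiches the data-dependent module between two
    linear layers.
    """
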
| def __init__(self, world_size): |
| super().__init__() |
| self.l1 = nn.Linear(10, 10) |
| self.ddm = DataDependentModule(world_size) |
| self.l2 = nn.Linear(10, 10) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| assert len(x.size()) == 2 |
| |
| return self.relu(self.l2(self.ddm(self.l1(x)))) |
| |
| |
| def ddm(x: torch.Tensor) -> torch.Tensor: |
| return x |
| |
| |
| def ddm_backward(grad: torch.Tensor) -> torch.Tensor: |
| return grad |
| |
| |
| dummy_lib = torch.library.Library("dummy", "DEF") |
| dummy_lib.define("ddm(Tensor x) -> Tensor") |
| dummy_lib.impl("ddm", ddm, "CompositeExplicitAutograd") |
| dummy_lib.define("ddm_backward(Tensor x) -> Tensor") |
| dummy_lib.impl("ddm_backward", ddm_backward, "CompositeExplicitAutograd") |
| |
| |
| def _identity_prop_rule(op_schema: OpSchema) -> OutputSharding: |
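    """Propagate the input sharding unchanged (identity op)."""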
| (x,) = op_schema.args_schema |
| assert isinstance(x, DTensorSpec), f"expecting DTensorSpec but got {x}" |
| |
| return OutputSharding(output_spec=DTensorSpec(x.mesh, x.placements)) |
| |
| |
| @register_prop_rule(torch.ops.dummy.ddm.default) |
| def _prop_ddm(op_schema: OpSchema) -> OutputSharding: |
| return _identity_prop_rule(op_schema) |
| |
| |
| @register_prop_rule(torch.ops.dummy.ddm_backward.default) |
| def _prop_ddm_backward(op_schema: OpSchema) -> OutputSharding: |
| return _identity_prop_rule(op_schema) |
| |
| |
| class DDMFunction(torch.autograd.Function): |
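    """Route forward and backward through the custom dummy ops so that both
    appear as call_function nodes in the traced graph.
    """
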
| @staticmethod |
| def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor: |
| return torch.ops.dummy.ddm(x) |
| |
| @staticmethod |
| def backward(ctx: Any, grad_x: torch.Tensor) -> torch.Tensor: |
| return torch.ops.dummy.ddm_backward(grad_x) |
| |
| |
| class DummyDDM(nn.Module): |
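    """Replacement submodule installed by DDMOverride; it dispatches to
    DDMFunction instead of running the eager data-dependent implementation.
    """
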
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| return DDMFunction.apply(x) |
| |
| |
| class TraceTrainStepTest(DTensorTestBase): |
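    """Compile full train steps (forward, backward, optimizer step) with
    @compile and check gradients/parameters against DDP baselines.
    """
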
| @property |
| def world_size(self): |
| return 2 |
| |
| @skip_if_lt_x_gpu(2) |
| @with_comms |
| def test_train_step_simple(self): |
| @compile() |
| def train_step(mod, inp): |
| mod(inp).sum().backward() |
| return [p.grad for p in mod.parameters()] |
| |
| inp = torch.randn(2, 10).cuda(self.rank) |
| # FIXME(@mrshenli): remove manual seed once dist.compile can synchronize |
| # module parameters. |
| torch.manual_seed(0) |
| mod = nn.Linear(10, 10).cuda(self.rank) |
| |
| ddp_mod = DDP(deepcopy(mod), device_ids=[self.rank]) |
| ddp_inp = deepcopy(inp) |
| |
| grads = train_step(mod, inp) |
| ddp_mod(ddp_inp).sum().backward() |
| |
| for g1, p2 in zip(grads, ddp_mod.parameters()): |
| # FIXME(@mrshenli): DDP by default divides gradients by world size. |
| # Should we match that behavior? |
| self.assertEqual(g1 / self.world_size, p2.grad) |
| |
| def _test_optimizer(self, mod, ddp_mod, opt, ddp_opt, inp, train_step): |
| ddp_inp = deepcopy(inp) |
| |
| # materialize optimizer states |
| mod(inp).sum().backward() |
| opt.step() |
| opt.zero_grad() |
| |
| ddp_mod(ddp_inp).sum().backward() |
| ddp_opt.step() |
| ddp_opt.zero_grad() |
| |
| # test parameter parity |
| train_step(mod, opt, inp) |
| |
| ddp_mod(ddp_inp).sum().backward() |
| # FIXME(@mrshenli): DDP by default divides grads by world size, but |
| # torch.distributed.compile does not do that yet. |
| with torch.no_grad(): |
| for p in ddp_mod.parameters(): |
| p.grad *= self.world_size |
| ddp_opt.step() |
| |
| for p1, p2 in zip(mod.parameters(), ddp_mod.parameters()): |
| self.assertEqual(p1, p2) |
| |
| @skip_if_lt_x_gpu(2) |
| @with_comms |
| def test_sgd(self): |
| @compile() |
| def train_step(mod, opt, inp): |
| mod(inp).sum().backward() |
| opt.step() |
| |
| # FIXME(@mrshenli): remove manual seed once dist.compile can synchronize |
| # module parameters. |
| torch.manual_seed(1) |
| mod = nn.Linear(10, 10, bias=True).cuda(self.rank) |
| opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=True) |
| inp = torch.randn(2, 10).cuda(self.rank) |
| |
| ddp_mod = DDP(deepcopy(mod), device_ids=[self.rank]) |
| ddp_opt = torch.optim.SGD(ddp_mod.parameters(), lr=0.01, foreach=True) |
| self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step) |
| |
| def _test_adam(self, *, foreach: bool, fused: bool): |
| class AssertOverride(Override): |
| def __init__(self, outer): |
| self.outer = outer |
| |
| def replacement(self, orig_submodule: torch.nn.Module) -> torch.nn.Module: |
| return orig_submodule |
| |
| def transform( |
| self, |
| gm: fx.GraphModule, |
| flat_state: List[torch.Tensor], |
            ) -> fx.GraphModule:
                # check that dedup succeeded: there should be only one allreduce
| self.outer.assertEqual( |
| len( |
| [ |
| n |
| for n in gm.graph.nodes |
| if n.target == torch.ops.c10d_functional.all_reduce.default |
| ] |
| ), |
| 1, |
| ) |
| |
| return gm |
| |
| @compile(module_override={nn.Linear: AssertOverride(self)}) |
| def train_step(mod, opt, inp): |
| mod(inp).sum().backward() |
| opt.step() |
| |
| # FIXME(@mrshenli): remove manual seed once dist.compile can synchronize |
| # module parameters. |
| torch.manual_seed(0) |
        # FIXME(@mrshenli): gradients for bias are missing
| mod = nn.Sequential(nn.Linear(10, 10, bias=False)).cuda(self.rank) |
| opt = torch.optim.Adam( |
| mod.parameters(), |
| lr=0.01, |
| foreach=foreach, |
| fused=fused, |
| capturable=True, |
| ) |
| inp = torch.randn(2, 10).cuda(self.rank) |
| |
| ddp_mod = DDP(deepcopy(mod), device_ids=[self.rank]) |
| ddp_opt = torch.optim.Adam( |
| ddp_mod.parameters(), lr=0.01, foreach=foreach, fused=fused |
| ) |
| self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step) |
| |
| @skip_if_lt_x_gpu(2) |
| @with_comms |
| def test_adam_foreach(self): |
| self._test_adam(foreach=True, fused=False) |
| |
| @skip_if_lt_x_gpu(2) |
| @with_comms |
| def test_adam_fused(self): |
| self._test_adam(foreach=False, fused=True) |
| |
| @skip_if_lt_x_gpu(2) |
| @with_comms |
| def test_train_step_override(self): |
| transform_targets = [] |
| |
| class DDMOverride(Override): |
| def replacement(self, orig_submodule: torch.nn.Module) -> torch.nn.Module: |
| return DummyDDM() |
| |
| def transform( |
| self, |
| gm: fx.GraphModule, |
| flat_state: List[torch.Tensor], |
            ) -> fx.GraphModule:
| nonlocal transform_targets |
| for node in gm.graph.nodes: |
| if node.target in [ |
| torch.ops.dummy.ddm.default, |
| torch.ops.dummy.ddm_backward.default, |
| ]: |
| transform_targets.append(node.target) |
                        # N.B.: this is not a complete subgraph representing the
                        # original logic, as we are testing the ability to
                        # modify the graph after DTensor expansion.
| with gm.graph.inserting_before(node): |
| new_node = gm.graph.call_function(torch.add, args=node.args) |
| node.replace_all_uses_with(new_node) |
| |
| gm.graph.lint() |
| gm.graph.eliminate_dead_code() |
| |
| return gm |
| |
| @compile(module_override={DataDependentModule: DDMOverride()}) |
| def train_step(mod, opt, inp): |
| mod(inp).sum().backward() |
| opt.step() |
| |
| mod = DummyModel(self.world_size).cuda(self.rank) |
| opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=False) |
        # FIXME: symbolic tracing treats bs=1 as a constant, so use bs > 1.
| inp = torch.randn(4, 10).cuda(self.rank) |
| train_step(mod, opt, inp) |
| |
        # check that the transforms are indeed invoked.
| self.assertEqual( |
| transform_targets, |
| [torch.ops.dummy.ddm.default, torch.ops.dummy.ddm_backward.default], |
| ) |
| |
| @skip_if_lt_x_gpu(2) |
| @with_comms |
| def test_gm_cache_and_transformation(self): |
| class GraphOptimization: |
| def __init__(self): |
| self.call_count = 0 |
| |
| def __call__(self, gm: fx.GraphModule) -> fx.GraphModule: |
| self.call_count += 1 |
| return gm |
| |
| graph_optimization = GraphOptimization() |
| |
| @compile(gm_transformation=graph_optimization) |
| def train_step(mod, opt, inp): |
| mod(inp).sum().backward() |
| opt.step() |
| |
| rank = torch.distributed.get_rank() |
| torch.manual_seed(0) |
| mod = nn.Linear(10, 10, bias=False).cuda(rank) |
| opt = torch.optim.Adam( |
| mod.parameters(), lr=0.01, foreach=False, capturable=True |
| ) |
| inp = torch.randn(2, 10).cuda(rank) |
| |
| # materialize optimizer states |
| mod(inp).sum().backward() |
| opt.step() |
| opt.zero_grad() |
| |
| train_step(mod, opt, inp) |
| self.assertEqual(graph_optimization.call_count, 1) |
| gm = train_step.__dict__[COMPILED_OBJECT_KEY].gm |
| train_step(mod, opt, inp) |
| self.assertEqual(id(gm), id(train_step.__dict__[COMPILED_OBJECT_KEY].gm)) |
| self.assertEqual(graph_optimization.call_count, 1) |
| |
| |
| class CoverageTest(DTensorTestBase): |
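    """Op-coverage tests: compile train steps for models exercising ops such
    as log_softmax and nll_loss, and compare resulting parameters with DDP.
    """
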
| @property |
| def world_size(self): |
| return 2 |
| |
| def _test_train_step(self, train_step, mod, *args): |
| ddp_mod = DDP(deepcopy(mod), device_ids=[self.rank]) |
| |
| opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=True) |
| ddp_opt = torch.optim.SGD(ddp_mod.parameters(), lr=0.01, foreach=True) |
| |
| ddp_args = deepcopy(args) |
| |
| # materialize optimizer states |
| mod(*args).sum().backward() |
| opt.step() |
| opt.zero_grad() |
| |
| ddp_mod(*ddp_args).sum().backward() |
| ddp_opt.step() |
| ddp_opt.zero_grad() |
| |
| # test parameter parity |
| train_step(mod, opt, *args) |
| |
| ddp_mod(*ddp_args).sum().backward() |
| # FIXME(@mrshenli): DDP by default divides grads by world size, but |
| # torch.distributed.compile does not do that yet. |
| with torch.no_grad(): |
| for p in ddp_mod.parameters(): |
| p.grad *= self.world_size |
| ddp_opt.step() |
| |
| for p1, p2 in zip(mod.parameters(), ddp_mod.parameters()): |
| self.assertEqual(p1, p2) |
| |
| @skip_if_lt_x_gpu(2) |
| @with_comms |
| def test_log_softmax(self): |
| torch.manual_seed(0) |
| |
| @compile() |
| def train_step(mod, opt, inp): |
| mod(inp).sum().backward() |
| opt.step() |
| |
| mod = nn.Sequential( |
| nn.Linear(10, 10), |
| nn.LogSoftmax(dim=1), |
| ).cuda(self.rank) |
| inp = torch.randn(2, 10).cuda(self.rank) |
| self._test_train_step(train_step, mod, inp) |
| |
| @skip_if_lt_x_gpu(2) |
| @with_comms |
| def test_nll_loss(self): |
| class ModuleWithLoss(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mod = nn.Sequential( |
| nn.Linear(10, 10), |
| nn.LogSoftmax(dim=1), |
| ) |
| self.lss = nn.NLLLoss() |
| |
| def forward(self, x, tgt): |
| return self.lss(self.mod(x), tgt) |
| |
| torch.manual_seed(0) |
| mod = ModuleWithLoss().cuda(self.rank) |
| |
| @compile() |
| def train_step(mod, opt, inp, tgt): |
| mod(inp, tgt).backward() |
| opt.step() |
| |
| inp = torch.randn(2, 10).to(self.rank) |
| tgt = torch.empty(2, dtype=torch.long).random_(0, 10).to(self.rank) |
| |
| self._test_train_step(train_step, mod, inp, tgt) |
| |
| |
| if __name__ == "__main__": |
| run_tests() |