# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import time

import pytest
import torch
from torch import nn

from torch.distributed.pipeline.sync._balance import balance_by_size, balance_by_time, blockpartition
from torch.distributed.pipeline.sync._balance.profile import layerwise_sandbox

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")

devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")


def test_blockpartition():
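    # blockpartition.solve splits a sequence into `partitions` consecutive
    # blocks with balanced sums: here [1, 2, 3, 4] sums to 10 and [5, 6] to 11.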
    assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [[1, 2, 3, 4], [5, 6]]


def test_blockpartition_zeros():
    assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]]


def test_blockpartition_non_positive_partitions():
    with pytest.raises(ValueError):
        blockpartition.solve([42], partitions=0)
    with pytest.raises(ValueError):
        blockpartition.solve([42], partitions=-1)


def test_blockpartition_short_sequence():
    with pytest.raises(ValueError):
        blockpartition.solve([], partitions=1)
    with pytest.raises(ValueError):
        blockpartition.solve([42], partitions=2)


@pytest.mark.parametrize("device", devices)
@pytest.mark.skip(reason="Flaky due to time.sleep()")
def test_balance_by_time(device):
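    # Six layers sleep for 0.1s through 0.6s. The most even two-way split of
    # those times is 0.1 + 0.2 + 0.3 + 0.4 = 1.0s vs. 0.5 + 0.6 = 1.1s, so the
    # expected balance is [4, 2].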
    class Delay(nn.Module):
        def __init__(self, seconds):
            super().__init__()
            self.seconds = seconds

        def forward(self, x):
            time.sleep(self.seconds)
            return x

    model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]])
    sample = torch.rand(1)
    balance = balance_by_time(2, model, sample, device=device)
    assert balance == [4, 2]


def test_balance_by_time_loop_resets_input():
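    # balance_by_time profiles the model repeatedly; every pass has to start
    # from a fresh copy of the sample, otherwise the flattened output of the
    # previous pass would no longer be a valid 4-D input for Conv2d.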
    # nn.Flatten was only introduced in PyTorch 1.2.0, so define a local equivalent.
    class Flatten(nn.Module):
        def forward(self, x):
            return x.flatten(1)

    model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10))
    sample = torch.rand(10, 3, 8, 8)
    balance = balance_by_time(2, model, sample, device="cpu")
    assert balance == [1, 2]


@skip_if_no_cuda
def test_balance_by_size_latent():
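    # balance_by_size profiles peak CUDA memory per layer. Each Expand(i) keeps
    # i extra grad-requiring tensors alive, so activation ("latent") memory
    # grows with i and the heavier layers end up in the smaller partition.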
    class Expand(nn.Module):
        def __init__(self, times):
            super().__init__()
            self.times = times

        def forward(self, x):
            for i in range(self.times):
                x = x + torch.rand_like(x, requires_grad=True)
            return x

    sample = torch.rand(10, 100, 100)

    model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]])
    balance = balance_by_size(2, model, sample)
    assert balance == [4, 2]

    model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]])
    balance = balance_by_size(2, model, sample)
    assert balance == [2, 4]


@skip_if_no_cuda
def test_balance_by_size_param():
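    # With param_scale=100, parameter memory dominates. The Linear layers grow
    # (or shrink) monotonically in size, so the partition holding the bigger
    # weight matrices is expected to get fewer layers.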
    model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)])
    sample = torch.rand(7, 1)
    balance = balance_by_size(2, model, sample, param_scale=100)
    assert balance == [4, 2]

    model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))])
    sample = torch.rand(1, 7)
    balance = balance_by_size(2, model, sample, param_scale=100)
    assert balance == [2, 4]


@skip_if_no_cuda
def test_balance_by_size_param_scale():
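    # Each Tradeoff layer trades parameter size against activation size. With
    # param_scale=0 only the profiled activation memory counts, so the
    # high-latent layers at the front dominate; with param_scale=100 parameter
    # memory dominates instead and the expected balance flips.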
    class Tradeoff(nn.Module):
        def __init__(self, param_size, latent_size):
            super().__init__()
            self.fc = nn.Linear(param_size, param_size)
            self.latent_size = latent_size

        def forward(self, x):
            for i in range(self.latent_size):
                x = x + torch.rand_like(x, requires_grad=True)
            return x

    model = nn.Sequential(
        Tradeoff(param_size=1, latent_size=6),
        Tradeoff(param_size=2, latent_size=5),
        Tradeoff(param_size=3, latent_size=4),
        Tradeoff(param_size=4, latent_size=3),
        Tradeoff(param_size=5, latent_size=2),
        Tradeoff(param_size=6, latent_size=1),
    )

    sample = torch.rand(1, requires_grad=True)

    balance = balance_by_size(2, model, sample, param_scale=0)
    assert balance == [2, 4]

    balance = balance_by_size(2, model, sample, param_scale=100)
    assert balance == [4, 2]


@pytest.mark.parametrize("device", devices)
def test_layerwise_sandbox(device):
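    # layerwise_sandbox yields training-mode copies of each layer on the target
    # device; the original model must stay untouched, i.e. in eval mode on CPU.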
    model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
    model.eval()

    for layer in layerwise_sandbox(model, torch.device(device)):
        assert layer.training
        assert all(p.device.type == device for p in layer.parameters())

    assert all(not layer.training for layer in model)
    assert all(p.device.type == "cpu" for p in model.parameters())


@pytest.mark.parametrize("device", devices)
def test_sandbox_during_profiling(device):
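    # Profiling must run inside a sandboxed copy of the model: the BatchNorm
    # running statistics of the original must not change, even though the
    # profiled forward passes would normally update them.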
    model = nn.Sequential(nn.BatchNorm2d(3))

    before = {k: v.clone() for k, v in model.state_dict().items()}

    sample = torch.rand(1, 3, 10, 10)
    balance_by_time(1, model, sample, device=device)

    after = model.state_dict()

    assert before.keys() == after.keys()
    for key, value in before.items():
        assert torch.allclose(after[key], value), key


def test_not_training():
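    # The sandbox runs layers in training mode while profiling, but the
    # original model's training flag must be left as-is.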
    class AssertTraining(nn.Module):
        def forward(self, x):
            assert self.training
            return x

    model = nn.Sequential(AssertTraining())

    model.eval()
    assert not model.training

    sample = torch.rand(1)
    balance_by_time(1, model, sample, device="cpu")

    assert not model.training


def test_balance_by_time_tuple():
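    # Layers may pass tuples of tensors to each other; the time profiler should
    # handle tuple-valued inter-layer data without raising.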
    class Twin(nn.Module):
        def forward(self, x):
            return x, x.detach()

    class Add(nn.Module):
        def forward(self, a_b):
            a, b = a_b
            return a + b

    model = nn.Sequential(Twin(), Add())
    sample = torch.rand(1, requires_grad=True)
    balance_by_time(1, model, sample, device="cpu")


@skip_if_no_cuda
def test_balance_by_size_tuple():
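    # Same as test_balance_by_time_tuple, but for the CUDA memory profiler.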
    class Twin(nn.Module):
        def forward(self, x):
            return x, x.detach()

    class Add(nn.Module):
        def forward(self, a_b):
            a, b = a_b
            return a + b

    model = nn.Sequential(Twin(), Add())
    sample = torch.rand(1, requires_grad=True)
    balance_by_size(1, model, sample)


def test_already_has_grad():
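    # The profiler runs backward passes of its own, so it refuses to start when
    # any parameter already holds a gradient.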
    model = nn.Sequential(nn.Conv2d(3, 3, 1))
    sample = torch.rand(1, 3, 32, 32)
    model(sample).norm().backward()

    with pytest.raises(ValueError, match="some parameter already has gradient"):
        balance_by_time(1, model, sample, device="cpu")