|  | # Owner(s): ["oncall: distributed"] | 
|  |  | 
|  | # Copyright 2019 Kakao Brain | 
|  | # | 
|  | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | 
|  | # | 
|  | # This source code is licensed under the BSD license found in the | 
|  | # LICENSE file in the root directory of this source tree. | 
|  | from collections import OrderedDict | 
|  | from copy import deepcopy | 
|  | import time | 
|  |  | 
|  | import pytest | 
|  | import random | 
|  | import torch | 
|  | from torch import nn | 
|  | from torch import Tensor | 
|  |  | 
|  | from torch.distributed.pipeline.sync import Pipe, NoChunk, WithDevice | 
|  | from torch.distributed.pipeline.sync.pipe import PipeSequential | 
|  | from torch.testing._internal.common_utils import run_tests | 
|  |  | 
# Reusable marker: skips the decorated test when torch reports no CUDA support.
skip_if_no_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="cuda required",
)
|  |  | 
|  |  | 
def test_pipe_without_rpc():
    """Constructing a ``Pipe`` must raise ``RuntimeError`` asking for RPC init.

    This test takes no ``setup_rpc`` argument, so the fixture is not requested.
    """
    model = nn.Sequential(nn.Linear(1, 1))
    with pytest.raises(RuntimeError, match='Please initialize RPC framework'):
        # Only the raised error matters; the result is intentionally not bound.
        Pipe(model, chunks=1)
|  |  | 
def test_parameters(setup_rpc):
    """A ``Pipe`` wrapping a single ``Linear`` layer exposes a non-empty parameter list."""
    wrapped = Pipe(nn.Sequential(nn.Linear(1, 1)), chunks=1)
    collected = list(wrapped.parameters())
    assert collected != []
|  |  | 
|  |  | 
def test_public_attrs(setup_rpc):
    """Check the ``devices``, ``chunks`` and ``checkpoint`` attributes of ``Pipe``.

    ``chunks`` is supplied as the float ``42.000`` and ``checkpoint`` as an
    object that only implements ``__str__``; the attributes read back are
    asserted to be a real ``int`` and a real ``str``.
    """

    class MyString:
        """Holds a value and returns it from ``__str__``."""

        def __init__(self, value):
            self.value = value

        def __str__(self):
            return self.value

    pipe = Pipe(
        nn.Sequential(nn.Linear(1, 1)),
        chunks=42.000,
        checkpoint=MyString("always"),
    )

    cpu_only = [torch.device("cpu")]
    assert pipe.devices == cpu_only

    assert pipe.chunks == 42
    assert isinstance(pipe.chunks, int)

    assert pipe.checkpoint == "always"
    assert isinstance(pipe.checkpoint, str)
|  |  | 
|  |  | 
def test_sequential_like(setup_rpc):
    """``Pipe`` supports ``len``, iteration and (negative) indexing over its layers."""
    first = nn.Linear(1, 1)
    second = nn.Linear(1, 1)
    model = Pipe(nn.Sequential(first, second))

    assert len(model) == 2
    assert list(model) == [first, second]

    # Non-negative indices, then one past the end.
    for index, layer in enumerate((first, second)):
        assert model[index] is layer
    with pytest.raises(IndexError):
        _ = model[2]

    # Negative indices count from the end.
    for index, layer in ((-1, second), (-2, first)):
        assert model[index] is layer
|  |  | 
def test_chunks_less_than_1(setup_rpc):
    """``chunks`` values of 0 and -1 are each rejected with ``ValueError``."""
    model = nn.Sequential(nn.Linear(1, 1))

    for invalid_chunks in (0, -1):
        with pytest.raises(ValueError):
            Pipe(model, chunks=invalid_chunks)
|  |  | 
def test_batch_size_indivisible(setup_rpc):
    """Feed a batch of 7 rows to a ``Pipe`` with ``chunks=4`` and require no warnings.

    Warnings are captured with :func:`warnings.catch_warnings` instead of
    ``pytest.warns(None)``, which pytest deprecated in 7.0 and removed in 8.0.
    """
    import warnings  # function-scope import keeps this block self-contained

    model = nn.Sequential(nn.Linear(1, 1))
    model = Pipe(model, chunks=4)

    with warnings.catch_warnings(record=True) as record:
        # Record every warning, regardless of filters already installed.
        warnings.simplefilter("always")
        model(torch.rand(7, 1))

    # Indivisible batch size is legal.
    assert not record
|  |  | 
|  |  | 
def test_batch_size_small(setup_rpc):
    """Feed a batch of 2 rows to a ``Pipe`` with ``chunks=4`` and require no warnings.

    Warnings are captured with :func:`warnings.catch_warnings` instead of
    ``pytest.warns(None)``, which pytest deprecated in 7.0 and removed in 8.0.
    """
    import warnings  # function-scope import keeps this block self-contained

    model = nn.Sequential(nn.Linear(1, 1))
    model = Pipe(model, chunks=4)

    with warnings.catch_warnings(record=True) as record:
        # Record every warning, regardless of filters already installed.
        warnings.simplefilter("always")
        model(torch.rand(2, 1))

    # Batch size smaller than chunks is legal.
    assert not record
|  |  | 
|  |  | 
def test_checkpoint_mode(setup_rpc):
    """Count ``CheckpointBackward`` nodes in the output graph for each checkpoint mode.

    With ``chunks=2`` the expected counts are 2 for ``"always"``, 1 for
    ``"except_last"`` and 0 for ``"never"``.
    """

    def count_grad_fn(grad_fn, name, visited=None):
        """Count graph nodes whose class name equals ``name``.

        Each node is visited at most once, and the walk does not descend
        below a node that matches.
        """
        seen = set() if visited is None else visited
        if grad_fn in seen:
            return 0
        seen.add(grad_fn)

        if grad_fn is None:
            return 0
        if type(grad_fn).__name__ == name:
            return 1

        return sum(
            count_grad_fn(parent, name, visited=seen)
            for parent, _ in grad_fn.next_functions
        )

    model = nn.Sequential(nn.Linear(1, 1))
    batch = torch.rand(2, 1)

    # Build all pipes first, then run them, then inspect the graphs.
    modes = ("always", "except_last", "never")
    pipes = [Pipe(model, chunks=2, checkpoint=mode) for mode in modes]
    outputs = [pipe(batch) for pipe in pipes]

    for output, expected_count in zip(outputs, (2, 1, 0)):
        found = count_grad_fn(output.local_value().grad_fn, "CheckpointBackward")
        assert found == expected_count
|  |  | 
|  |  | 
def test_checkpoint_mode_invalid(setup_rpc):
    """An unknown ``checkpoint`` string is rejected with a descriptive ``ValueError``."""
    layers = nn.Sequential(nn.Linear(1, 1))
    expected_message = "checkpoint is not one of 'always', 'except_last', or 'never'"

    with pytest.raises(ValueError, match=expected_message):
        Pipe(layers, chunks=2, checkpoint="INVALID_CHECKPOINT")
|  |  | 
|  |  | 
def test_checkpoint_mode_when_chunks_1(setup_rpc):
    """Every checkpoint mode can be used to construct a ``Pipe`` with ``chunks=1``."""
    layers = nn.Sequential(nn.Linear(1, 1))

    # Construction alone is the check; none of these should raise.
    for mode in ("except_last", "always", "never"):
        Pipe(layers, chunks=1, checkpoint=mode)
|  |  | 
|  |  | 
def test_checkpoint_eval(setup_rpc):
    """Checkpoint/recompute graph nodes appear in train mode and are absent in eval mode."""

    def find_grad_fn(grad_fn, name):
        """Return True if a node with class name ``name`` is reachable from ``grad_fn``."""
        if grad_fn is None:
            return False
        if type(grad_fn).__name__ == name:
            return True
        return any(
            find_grad_fn(parent, name) for parent, _ in grad_fn.next_functions
        )

    model = Pipe(nn.Sequential(nn.Linear(1, 1)), chunks=2)
    batch = torch.rand(2, 1)
    node_names = ("CheckpointBackward", "RecomputeBackward")

    model.train()
    train_output = model(batch)
    for node_name in node_names:
        assert find_grad_fn(train_output.local_value().grad_fn, node_name)

    model.eval()
    eval_output = model(batch)
    for node_name in node_names:
        assert not find_grad_fn(eval_output.local_value().grad_fn, node_name)
|  |  | 
|  |  | 
def test_checkpoint_non_float_input(setup_rpc):
    """Forward and backward succeed when a bool tensor flows between checkpointed stages."""

    class ForkNonFloat(nn.Module):
        """Emits the doubled input together with a boolean tensor."""

        def forward(self, input):
            doubled = input * 2
            flag = torch.tensor([False])
            return (doubled, flag)

    class JoinNonFloat(nn.Module):
        """Doubles the first argument and ignores the non-float one."""

        def forward(self, input, non_float):
            return input * 2

    stages = nn.Sequential(ForkNonFloat(), JoinNonFloat())
    model = Pipe(stages, chunks=1, checkpoint="always")

    sample = torch.rand(1, requires_grad=True)
    result = model(sample)
    result.backward()
|  |  | 
|  |  | 
def test_no_grad(setup_rpc):
    """Under ``torch.no_grad()`` the output of the first partition has no ``grad_fn``."""
    model = Pipe(nn.Sequential(nn.Linear(1, 1)), chunks=2)
    batch = torch.rand(2, 1)

    # One-slot holder that the forward hook overwrites with the latest output.
    captured = [None]

    def hook(module, input, output):
        captured[0] = output

    first_partition = model.partitions[0]
    first_partition.register_forward_hook(hook)

    with torch.no_grad():
        model(batch)

    assert captured[0].grad_fn is None
|  |  | 
|  |  | 
def test_exception(setup_rpc):
    """An exception raised inside a layer's ``forward`` reaches the caller of the pipe."""

    class ExpectedException(Exception):
        pass

    class Raise(nn.Module):
        """Layer whose forward pass always fails."""

        def forward(self, *_):
            raise ExpectedException()

    failing_pipe = Pipe(nn.Sequential(Raise()), chunks=1)

    with pytest.raises(ExpectedException):
        failing_pipe(torch.rand(1))
|  |  | 
|  |  | 
def test_exception_early_stop_asap(setup_rpc):
    """Even when the first partitions have already finished processing, the
    partition just before the failing one should be stopped as soon as possible.
    """

    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        """Identity layer."""

        def forward(self, x):
            return x

    # One entry is appended per completed ``Counter.forward`` call.
    completed_calls = []

    class Counter(nn.Module):
        """Identity layer that sleeps briefly and then records that it ran."""

        def forward(self, x):
            time.sleep(0.1)
            completed_calls.append(1)
            return x

    class Raise(nn.Module):
        """Layer whose forward pass always fails."""

        def forward(self, x):
            raise ExpectedException()

    stages = nn.Sequential(Pass(), Pass(), Counter(), Raise())
    model = Pipe(stages, chunks=3)

    with pytest.raises(ExpectedException):
        model(torch.rand(3))

    # Without an early stop, all 3 micro-batches would have been counted.
    assert len(completed_calls) == 2
|  |  | 
|  |  | 
def test_nested_input(setup_rpc):
    """A tuple input that itself contains a tuple or a list is rejected with ``TypeError``."""

    class NestedInput(nn.Module):
        """Returns its input unchanged; the two linear layers are never applied."""

        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, inp):
            return inp

    model = Pipe(nn.Sequential(NestedInput()), chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    # A nested tuple first, then a nested list; a Tensor is expected instead.
    for inner in ((a, b), [a, b]):
        with pytest.raises(TypeError):
            model((a, inner)).local_value()
|  |  | 
|  |  | 
def test_input_pair(setup_rpc):
    """Two tensor inputs pass through the pipe and both receive gradients."""

    class Two(nn.Module):
        """Applies an independent linear layer to each of its two inputs."""

        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, a, b):
            out_a = self.fc_a(a)
            out_b = self.fc_b(b)
            return (out_a, out_b)

    model = Pipe(nn.Sequential(Two()), chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    a_out, b_out = model(a, b).local_value()
    (a_out + b_out).mean().backward()

    for leaf in (a, b):
        assert leaf.grad is not None
|  |  | 
def test_multi_sequence_input(setup_rpc):
    """Passing two lists of tensors as positional inputs raises ``TypeError``."""

    class MultiSeq(nn.Module):
        """Returns both arguments unchanged."""

        def forward(self, tup1, tup2):
            return tup1, tup2

    model = Pipe(nn.Sequential(MultiSeq()))

    with pytest.raises(TypeError):
        first_pair = [torch.rand(10), torch.rand(10)]
        second_pair = [torch.rand(10), torch.rand(10)]
        model(first_pair, second_pair)
|  |  | 
def test_input_singleton(setup_rpc):
    """A layer returning a one-element tuple works, and gradients reach parameters and input."""

    class One(nn.Module):
        """Wraps the linear layer's output in a 1-tuple."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)

        def forward(self, a):
            return (self.fc(a),)

    model = Pipe(nn.Sequential(One()), chunks=2)
    a = torch.rand(10, 1, requires_grad=True)

    (a_out,) = model(a).local_value()
    a_out.mean().backward()

    assert all(param.grad is not None for param in model.parameters())
    assert a.grad is not None
|  |  | 
|  |  | 
def test_input_varargs(setup_rpc):
    """Giving two tensors to a pipe whose only layer is ``Linear`` raises ``TypeError``."""
    model = Pipe(nn.Sequential(nn.Linear(1, 1)))

    first = torch.rand(1)
    second = torch.rand(1)

    # Expected failure: forward() takes 2 positional arguments but 3 were given.
    with pytest.raises(TypeError):
        model(first, second)
|  |  | 
|  |  | 
def test_non_tensor(setup_rpc):
    """``TypeError`` is raised for a string-returning layer, with a tensor or a string input."""

    class NonTensor(nn.Module):
        """Ignores its input and returns a plain string."""

        def forward(self, _):
            return "hello"

    model = Pipe(nn.Sequential(NonTensor()))
    x = torch.rand(1)

    # First a tensor input, then a non-tensor input.
    for candidate in (x, "hello"):
        with pytest.raises(TypeError):
            model(candidate)
|  |  | 
|  |  | 
def test_non_tensor_sequence(setup_rpc):
    """Sequences mixing tensors with strings, and all-non-tensor arguments, raise ``TypeError``."""

    class NonTensorTuple(nn.Module):
        """Returns the input paired with a string."""

        def forward(self, x):
            return (x, "hello")

    class NonTensorArgs(nn.Module):
        """Takes and returns only non-tensor values."""

        def forward(self, x: str, y: bool):
            return x, y

    model = Pipe(nn.Sequential(NonTensorTuple()))
    x = torch.rand(1)

    # The same mixed content as a tuple and then as a list.
    for mixed in ((x, "hello"), [x, "hello"]):
        with pytest.raises(TypeError):
            model(mixed)

    model = Pipe(nn.Sequential(NonTensorArgs()))

    # At least one Tensor is needed among the arguments.
    with pytest.raises(TypeError):
        model("hello", True)
|  |  | 
|  |  | 
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_valid_non_tensor(checkpoint, setup_rpc):
    """Mix ints, bools, strings, ``None`` and tensors as stage inputs and outputs.

    With ``chunks=5`` the non-tensor outputs are asserted to come back as
    5-element lists, while tensor outputs are compared against values
    computed directly from the inputs.
    """

    class NonTensor1(nn.Module):
        def forward(self, a: int, b: Tensor, c: bool, d: Tensor):
            if c:
                res = b + a
            else:
                res = b * a
            if d is not None:
                res += d
            return res, c, a, b, "hello", d

    class NonTensor2(nn.Module):
        def forward(self, a: Tensor, b: bool, c: int, d: Tensor, e: str, f: Tensor):
            if b:
                res = a * c
            else:
                res = a + c
            res += d
            combined = d if f is None else d + f
            return c, res, a, combined, b, e, f

    stages = nn.Sequential(NonTensor1(), NonTensor2())
    model = Pipe(stages, chunks=5, checkpoint=checkpoint)

    a = random.randint(0, 10)
    b = torch.rand(10, 10)
    c = random.randint(0, 1) == 0
    d = torch.rand(10, 10)

    res = model(a, b, c, d).local_value()
    assert len(res) == 7
    assert res[0] == [a] * 5
    if c:
        expected_final = ((b + a + d) * a) + b
        expected_first_stage = b + a + d
    else:
        expected_final = ((b * a) + d + a) + b
        expected_first_stage = b * a + d
    assert torch.allclose(expected_final, res[1])
    assert torch.allclose(expected_first_stage, res[2])
    assert torch.allclose(b + d, res[3])
    assert res[4] == [c] * 5
    assert res[5] == ["hello"] * 5
    assert torch.allclose(d, res[6])

    # One of the tensor arguments may be None.
    res = model(a, b, c, None).local_value()
    assert len(res) == 7
    assert res[0] == [a] * 5
    if c:
        expected_final = ((b + a) * a) + b
        expected_first_stage = b + a
    else:
        expected_final = ((b * a) + a) + b
        expected_first_stage = b * a
    assert torch.allclose(expected_final, res[1])
    assert torch.allclose(expected_first_stage, res[2])
    assert torch.allclose(b, res[3])
    assert res[4] == [c] * 5
    assert res[5] == ["hello"] * 5
    assert res[6] == [None] * 5

    # At least one tensor is required.
    with pytest.raises(TypeError):
        model(a, None, c, None)
|  |  | 
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_no_tensor_output(checkpoint, setup_rpc):
    """A stage that emits no tensor at all must make the pipe raise ``TypeError``.

    The parametrized ``checkpoint`` mode is forwarded to ``Pipe`` so that each
    parametrization exercises a different configuration.
    """

    class Model1(nn.Module):
        """Drops the tensor argument and returns only non-tensor values."""

        def forward(self, a: int, b: Tensor, c: bool):
            return a, c, "hello"

    class Model2(nn.Module):
        """Reorders its three non-tensor arguments."""

        def forward(self, a: int, b: bool, c: str):
            return a, c, b

    model = Pipe(nn.Sequential(Model1(), Model2()), chunks=5, checkpoint=checkpoint)
    a = random.randint(0, 10)
    b = torch.rand(10, 10)
    c = random.randint(0, 1) == 0

    # Need at least one tensor across partitions too.
    with pytest.raises(TypeError):
        model(a, b, c).local_value()
|  |  | 
|  |  | 
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_uneven_batch_size(checkpoint, setup_rpc):
    """Tensors of 3 and 6 rows work with ``chunks=5``; tensors of 3 and 4 rows do not."""

    class Model(nn.Module):
        """Returns its three arguments unchanged."""

        def forward(self, a: Tensor, b: int, c: Tensor):
            return a, b, c

    model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5)
    a = torch.rand(3, 10)
    b = random.randint(0, 10)
    c = torch.rand(6, 10)

    out_a, out_b, out_c = (model(a, b, c).local_value()[i] for i in range(3))
    assert torch.allclose(a, out_a)
    # The int is asserted to come back three times (3 chunks).
    assert out_b == [b] * 3
    assert torch.allclose(c, out_c)

    # Two tensors that split into different numbers of chunks must fail.
    model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5)
    a = torch.rand(3, 10)
    b = random.randint(0, 10)
    c = torch.rand(4, 10)

    with pytest.raises(RuntimeError, match='Found different number of chunks'):
        model(a, b, c)
|  |  | 
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_no_chunk(checkpoint, setup_rpc):
    """A tensor wrapped in ``NoChunk`` comes back concatenated once per chunk."""

    class Model(nn.Module):
        """Returns its three arguments unchanged."""

        def forward(self, a: Tensor, b: int, c: Tensor):
            return a, b, c

    num_chunks = 5
    model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=num_chunks)
    a = torch.rand(10, 10)
    b = random.randint(0, 10)
    c = torch.rand(10, 10)

    res = model(a, b, NoChunk(c)).local_value()
    assert torch.allclose(a, res[0])
    assert res[1] == [b] * num_chunks
    # Because of NoChunk, the whole of c is replicated for every chunk, so
    # the output holds five copies of it joined together.
    assert torch.allclose(torch.cat([c] * num_chunks), res[2])

    # NoChunk accepts tensors only.
    with pytest.raises(TypeError, match='NoChunk only supported for tensors'):
        NoChunk(b)
|  |  | 
|  |  | 
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_deferred_batch_norm(checkpoint, setup_rpc):
    """Running statistics of a piped batch norm match those of a plain copy (atol 1e-4)."""
    reference = nn.BatchNorm2d(3)
    piped_copy = deepcopy(reference)
    pipe = Pipe(
        nn.Sequential(piped_copy),
        chunks=2,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).local_value().mean().backward()
    reference(x).mean().backward()

    for stat_name in ("running_mean", "running_var"):
        piped_stat = getattr(pipe[0], stat_name)
        reference_stat = getattr(reference, stat_name)
        assert torch.allclose(piped_stat, reference_stat, atol=1e-4)
|  |  | 
|  |  | 
@pytest.mark.parametrize("checkpoint", ["never", "always"])
def test_deferred_batch_norm_params(checkpoint, setup_rpc):
    """Weight and bias gradients of a piped batch norm match a plain copy (atol 1e-4)."""
    reference = nn.BatchNorm2d(3)
    piped_copy = deepcopy(reference)
    pipe = Pipe(
        nn.Sequential(piped_copy),
        chunks=1,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).local_value().mean().backward()
    reference(x).mean().backward()

    param_names = ("weight", "bias")

    # Gradients must exist before they can be compared.
    for param_name in param_names:
        assert getattr(pipe[0], param_name).grad is not None

    for param_name in param_names:
        piped_grad = getattr(pipe[0], param_name).grad
        reference_grad = getattr(reference, param_name).grad
        assert torch.allclose(piped_grad, reference_grad, atol=1e-4)
|  |  | 
|  |  | 
def test_devices(setup_rpc):
    """Three ``Linear`` layers built without a device yield three CPU entries in ``devices``."""
    layers = [nn.Linear(1, 1) for _ in range(3)]
    model = Pipe(nn.Sequential(*layers))

    cpu = torch.device("cpu")
    assert model.devices == [cpu] * 3
|  |  | 
|  |  | 
def test_partitions(setup_rpc):
    """``partitions`` is a ``ModuleList`` of ``Sequential``s that shows up in ``state_dict``."""
    first = nn.Linear(1, 1)
    second = nn.Linear(1, 1)
    model = Pipe(nn.Sequential(first, second))

    assert isinstance(model.partitions, nn.ModuleList)
    for index in (0, 1):
        assert isinstance(model.partitions[index], nn.Sequential)

    state_keys = model.state_dict()
    assert "partitions.0.0.weight" in state_keys
|  |  | 
|  |  | 
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_merged_partitions(setup_rpc):
    """Layers moved to device 0 share one partition; the two unmoved layers get one each."""
    a = nn.Linear(1, 1).to(0)
    b = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).to(0)
    c = nn.Linear(1, 1)
    d = nn.Linear(1, 2)

    model = Pipe(nn.Sequential(a, b, c, d))
    partitions = model.partitions

    assert isinstance(partitions, nn.ModuleList)
    for index in (0, 1):
        assert isinstance(partitions[index], PipeSequential)

    # The nested Sequential ``b`` appears flattened into its two layers.
    expected_layout = ([a, b[0], b[1]], [c], [d])
    for index, expected_layers in enumerate(expected_layout):
        assert list(partitions[index]) == expected_layers
|  |  | 
|  |  | 
def test_deny_moving(setup_rpc):
    """Device moves on a ``Pipe`` raise ``TypeError``; dtype casts are accepted."""
    first = nn.Linear(1, 1)
    second = nn.Linear(1, 1)
    model = Pipe(nn.Sequential(first, second))

    # Each callable attempts one kind of move. Arguments are built inside the
    # callable so that they are evaluated within the ``pytest.raises`` block.
    denied_moves = (
        lambda: model.cuda(),
        lambda: model.cpu(),
        lambda: model.to(torch.device("cuda")),
        lambda: model.to(0),
        lambda: model.to("cuda"),
        lambda: model.to(device=0),
        lambda: model.to(torch.rand(1)),
        lambda: model.to(tensor=torch.rand(1)),
    )
    for attempt_move in denied_moves:
        with pytest.raises(TypeError):
            attempt_move()

    # Casting is allowed.
    model.half()
    model.to(torch.double)
    model.to(dtype=torch.float)
|  |  | 
|  |  | 
def test_empty_module(setup_rpc):
    """An empty ``Sequential`` returns a tensor input as-is and rejects a bare int."""
    # An empty sequential module is accepted.
    model = Pipe(nn.Sequential())

    scalar = torch.tensor(42)
    assert model(scalar).local_value() == torch.tensor(42)

    # Inputs still have to be tensors.
    with pytest.raises(TypeError):
        model(42)
|  |  | 
|  |  | 
def test_named_children(setup_rpc):
    """Named children are listed under ``partitions.N.0`` and are not attributes of the pipe."""
    first = nn.Linear(1, 1)
    second = nn.Linear(1, 1)
    named_layers = OrderedDict([("a", first), ("b", second)])
    model = Pipe(nn.Sequential(named_layers))

    module_names = set()
    for module_name, _ in model.named_modules():
        module_names.add(module_name)
    for expected_name in ("partitions.0.0", "partitions.1.0"):
        assert expected_name in module_names

    # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires
    # several methods in its namespace.
    with pytest.raises(AttributeError):
        model.a
|  |  | 
|  |  | 
def test_verify_module_non_sequential(setup_rpc):
    """Wrapping a bare ``nn.Module`` raises ``TypeError`` naming ``nn.Sequential``."""
    expected_message = "module must be nn.Sequential to be partitioned"
    with pytest.raises(TypeError, match=expected_message):
        Pipe(nn.Module())
|  |  | 
|  |  | 
def test_verify_module_duplicate_children(setup_rpc):
    """Using the same layer object twice in the ``Sequential`` raises ``ValueError``."""
    shared_conv = nn.Conv2d(3, 3, 1)
    model = nn.Sequential(shared_conv, shared_conv)
    expected_message = "module with duplicate children is not supported"

    with pytest.raises(ValueError, match=expected_message):
        Pipe(model)
|  |  | 
|  |  | 
@skip_if_no_cuda
def test_verify_module_params_on_same_device(setup_rpc):
    """A child holding one unmoved conv and one ``.cuda()`` conv raises ``ValueError``."""

    class Surrogate(nn.Module):
        """Container that only stores two sub-modules."""

        def __init__(self, param1, param2):
            super().__init__()
            self.param1 = param1
            self.param2 = param2

    conv1 = nn.Conv2d(3, 3, 1)
    conv2 = nn.Conv2d(3, 3, 1)
    model = nn.Sequential(Surrogate(conv1, conv2.cuda()))

    expected_message = (
        r'should have all parameters on a single device, please use .to\(\)'
        ' to place the module on a single device'
    )
    with pytest.raises(ValueError, match=expected_message):
        Pipe(model)
|  |  | 
@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs")
def test_verify_nested_modules(setup_rpc):
    """Nested ``Sequential``s on two GPUs produce a ``[10, 2]`` output on ``cuda:1``."""
    front = nn.Sequential(
        nn.Linear(32, 16).cuda(0),
        nn.Linear(16, 8).cuda(0),
    )
    back = nn.Sequential(
        nn.Linear(8, 4).cuda(1),
        nn.Linear(4, 2).cuda(1),
    )
    pipe = Pipe(nn.Sequential(front, back))

    out = pipe(torch.rand(10, 32).cuda(0))
    assert out.local_value().device == torch.device("cuda:1")
    assert out.local_value().size() == torch.Size([10, 2])
|  |  | 
def test_verify_module_duplicate_parameters_on_same_device(setup_rpc):
    """Two distinct wrappers sharing one conv can be wrapped by ``Pipe`` without error."""

    class Surrogate(nn.Module):
        """Container that only stores the given module."""

        def __init__(self, module):
            super().__init__()
            self.module = module

    shared_conv = nn.Conv2d(3, 3, 1)
    wrappers = nn.Sequential(Surrogate(shared_conv), Surrogate(shared_conv))

    # Construction alone is the check; it should not raise.
    Pipe(wrappers)
|  |  | 
|  |  | 
def test_forward_lockstep(setup_rpc):
    """Record the order of forward calls across two stages for three micro-batches."""
    events = []

    class DelayedLog(nn.Module):
        """Sleeps, then logs ``(call_count, tag)`` and bumps its call count."""

        def __init__(self, j, seconds):
            super().__init__()
            self.call_count = 0
            self.tag = j
            self.delay = seconds

        def forward(self, x):
            time.sleep(self.delay)

            events.append((self.call_count, self.tag))
            self.call_count += 1

            return x

    fast_stage = DelayedLog(0, seconds=0)
    slow_stage = DelayedLog(1, seconds=0.1)
    model = Pipe(nn.Sequential(fast_stage, slow_stage), chunks=3)
    model(torch.rand(3, 1))

    # Expected timeline: (Logs are recorded at !)
    #
    # Partition #0: 0! 1!   2!
    # Partition #1:    000! 111! 222!
    #
    expected_events = [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)]
    assert events == expected_events
|  |  | 
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@skip_if_no_cuda
def test_multiple_inputs(checkpoint, setup_rpc):
    """Three inputs to stage one and two to stage two yield the sum plus the product."""

    class Module1(nn.Module):
        """Returns the sum and the product of its three inputs."""

        def forward(self, a, b, c):
            total = a + b + c
            product = a * b * c
            return total, product

    class Module2(nn.Module):
        """Adds its two inputs."""

        def forward(self, a, b):
            return a + b

    stages = nn.Sequential(Module1().cuda(0), Module2().cuda(0))
    model = Pipe(stages, chunks=2, checkpoint=checkpoint)

    t = torch.rand(10)
    res = model(t, t, t).local_value()
    expected = (t + t + t) + (t * t * t)
    assert torch.equal(res, expected)
|  |  | 
@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs")
def test_inputs_wrong_device(setup_rpc):
    """Inputs on ``cuda:1`` are rejected when the first partition lives on ``cuda:0``."""

    class Module1(nn.Module):
        """Adds a learnable 5-element parameter to the sum of both inputs."""

        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(5))

        def forward(self, a, b):
            return a + b + self.param, b

    # Both inputs start on a device other than the first partition's, and
    # the pipe is expected to refuse them with a ValueError.
    a = torch.rand(10).cuda(1)
    b = torch.rand(10).cuda(1)
    stages = nn.Sequential(Module1().cuda(0), Module1().cuda(1))
    model = Pipe(stages, chunks=2)

    expected_message = 'All inputs should be on the same device as the first partition'
    with pytest.raises(ValueError, match=expected_message):
        model(a, b)
|  |  | 
@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs")
def test_with_device_wrapper(setup_rpc):
    """Check output device and ``devices`` for three layouts that use ``WithDevice``."""
    fc1 = nn.Linear(16, 8).cuda(0)
    fc2 = nn.Linear(8, 4).cuda(1)
    dropout = nn.Dropout()

    def run_and_check(layers, output_device, partition_devices):
        """Pipe ``layers`` with 8 chunks, run a ``cuda:0`` batch and verify devices."""
        piped = Pipe(layers, chunks=8)
        result = piped(torch.rand(16, 16).cuda(0)).local_value()
        assert torch.device(output_device) == result.device
        assert [torch.device(name) for name in partition_devices] == piped.devices

    # Dropout wrapped for cuda:1, following fc2 which is already there.
    run_and_check(
        nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1')),
        'cuda:1',
        ['cuda:0', 'cuda:1'],
    )

    # Dropout wrapped for cuda:1, directly after fc1 on cuda:0.
    run_and_check(
        nn.Sequential(fc1, WithDevice(dropout, 'cuda:1')),
        'cuda:1',
        ['cuda:0', 'cuda:1'],
    )

    # fc2 wrapped for cuda:0; afterwards its weight is asserted to be on cuda:0.
    run_and_check(
        nn.Sequential(fc1, WithDevice(fc2, 'cuda:0')),
        'cuda:0',
        ['cuda:0'],
    )
    assert torch.device('cuda:0') == fc2.weight.device
|  |  | 
|  |  | 
# Script entry point: hand control to the imported ``run_tests`` helper when
# this module is executed directly.
if __name__ == "__main__":
    run_tests()