# Owner(s): ["oncall: distributed"]

from copy import deepcopy
from typing import Tuple

import torch
import torch.nn as nn
from torch.distributed._composable import contract
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
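

# A small two-branch model: seq1 and seq2 are Sequential stacks so the
# composable APIs under test can be applied to individual submodules, and
# `b` is a plain tensor mutated in forward() as an ad-hoc buffer.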
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])
        self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])
        self.p = nn.Parameter(torch.randn(10, 10), requires_grad=True)
        self.b = torch.zeros(1)  # buffer
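
    # Update the pseudo-buffer under no_grad so the in-place add does not
    # become part of the autograd graph.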
    def forward(self, x, y):
        with torch.no_grad():
            self.b += x.sum() + y.sum()
        return self.p + self.seq1(x) + self.seq2(y)


class TestContract(TestCase):
    @skipIfTorchDynamo("Dynamo does not yet capture module hooks")
    def test_add_hooks(self):
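        # Register one hook of every kind, all of them no-ops, and verify
        # that the contract-decorated API leaves gradients identical to an
        # unhooked copy of the model.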
        def forward_pre_hook(
            module: nn.Module, inp: Tuple[torch.Tensor, ...]
        ) -> Tuple[torch.Tensor, ...]:
            return inp

        def forward_hook(
            module: nn.Module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor
        ) -> torch.Tensor:
            return out

        def backward_pre_hook(
            module: nn.Module, grad_output: Tuple[torch.Tensor, ...]
        ) -> Tuple[torch.Tensor, ...]:
            return grad_output

        def backward_hook(
            module: nn.Module,
            grad_input: Tuple[torch.Tensor, ...],
            grad_output: Tuple[torch.Tensor, ...],
        ) -> Tuple[torch.Tensor, ...]:
            return grad_input
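
        # `contract` checks that the decorated API preserves parameter,
        # buffer, and module FQNs on the module it returns, and attaches a
        # state object retrievable via `noop_api.state(module)`.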
        @contract
        def noop_api(module: nn.Module) -> nn.Module:
            module.register_forward_pre_hook(forward_pre_hook)
            module.register_forward_hook(forward_hook)
            module.register_full_backward_pre_hook(backward_pre_hook)
            module.register_full_backward_hook(backward_hook)
            return module

        model = ToyModel()
        model_with_hooks = deepcopy(model)
        noop_api(model_with_hooks.seq1)
        noop_api(model_with_hooks.seq2)

        x, y = torch.randn(10, 10), torch.randn(10, 10)
        model(x, y).sum().backward()
        model_with_hooks(x, y).sum().backward()
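
        # All hooks are no-ops, so the hooked copy must produce exactly the
        # same gradients as the original model.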
        for p1, p2 in zip(model.parameters(), model_with_hooks.parameters()):
            self.assertEqual(p1, p2)

    @skipIfTorchDynamo("Dynamo does not yet capture module hooks")
    def test_modify_fqn(self):
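        # Wrapping a submodule inserts a "module." prefix into every
        # parameter FQN, which the contract must reject.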
        class ModelWrapper(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, x):
                return self.module(x)

        @contract
        def wrap_module(module: nn.Module) -> nn.Module:
            return ModelWrapper(module)

        model = ToyModel()
        with self.assertRaisesRegex(RuntimeError, "cannot modify FQNs"):
            wrap_module(model.seq1)

    @skipIfTorchDynamo("Dynamo does not yet capture module hooks")
    def test_state(self):
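        # Each contract-decorated API gets a per-module state object via
        # `api.state(module)`; hooks installed by the API can read and
        # mutate that state across forward calls.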
        def check_and_update_state_hook(
            module: nn.Module, inp: Tuple[torch.Tensor, ...]
        ) -> Tuple[torch.Tensor, ...]:
            self.assertEqual(api.state(module).dummy_state, 7)
            api.state(module).dummy_state = 8
            return inp

        # FIXME: the circular reference looks a bit odd. Should we make
        # `.state` a top-level API instead of attaching it to the contract
        # API?
        @contract
        def api(module: nn.Module) -> nn.Module:
            api.state(module).dummy_state = 7
            module.register_forward_pre_hook(check_and_update_state_hook)
            return module

        model = ToyModel()
        api(model.seq1)
        self.assertEqual(api.state(model.seq1).dummy_state, 7)
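
        # Running forward fires check_and_update_state_hook, which bumps
        # the state from 7 to 8.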
        model(torch.zeros(10, 10), torch.zeros(10, 10))
        self.assertEqual(api.state(model.seq1).dummy_state, 8)


if __name__ == "__main__":
    run_tests()