blob: 5fe984c2bd6a73c5d1af64aaee0da9446b8d3603 [file] [log] [blame]
# Owner(s): ["module: nn"]
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
skipIfTorchDynamo,
)
import torch
import torch.nn as nn
from functools import partial
from typing import List, Tuple
class Net(nn.Module):
    """Two ``nn.Sequential`` stages, each made of two 10 -> 10 ``nn.Linear`` layers."""

    def __init__(self) -> None:
        super().__init__()
        # seq1's layers are constructed before seq2's, left to right.
        self.seq1 = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        self.seq2 = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``seq1`` and then ``seq2`` to ``x``."""
        hidden = self.seq1(x)
        return self.seq2(hidden)
class ToyModel(nn.Module):
    """Two ``Net`` instances applied in sequence: ``net2(net1(x))``."""

    def __init__(self) -> None:
        super().__init__()
        self.net1 = Net()
        self.net2 = Net()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through ``net1``, then feed the result through ``net2``."""
        intermediate = self.net1(x)
        return self.net2(intermediate)
def forward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    inp: Tuple[torch.Tensor],
    out: torch.Tensor,
) -> None:
    """Forward hook that records ``hook_id`` and sanity-checks its arguments.

    The first four parameters are meant to be pre-bound with
    ``functools.partial``. ``module``, ``inp`` and ``out`` are the arguments
    the hook is invoked with.

    Args:
        self: test case used for assertions.
        fired_hooks: log that ``hook_id`` is appended to (before any check).
        expected_module: the exact object ``module`` must be (compared by id).
        hook_id: identifier recorded in ``fired_hooks``.
        module: module the hook fired on.
        inp: positional inputs; exactly one is expected.
        out: module output (not inspected).
    """
    fired_hooks.append(hook_id)
    # Identity rather than equality: the hook must see the very module
    # it was registered on.
    actual_id = id(module)
    self.assertEqual(actual_id, id(expected_module))
    input_count = len(inp)
    self.assertEqual(input_count, 1)
def forward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    inp: Tuple[torch.Tensor],
) -> None:
    """Forward-pre-hook that logs ``hook_id`` and validates its arguments.

    Bind ``self``, ``fired_hooks``, ``expected_module`` and ``hook_id`` with
    ``functools.partial``. The remaining two arguments (``module``, ``inp``)
    are supplied when the hook fires. ``hook_id`` is appended first, then the
    checks below run: ``module`` must be the very object ``expected_module``
    (compared by ``id``) and ``inp`` must contain exactly one element.
    """
    fired_hooks.append(hook_id)
    observed, wanted = id(module), id(expected_module)
    self.assertEqual(observed, wanted)
    self.assertEqual(len(inp), 1)
def full_backward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    grad_input: Tuple[torch.Tensor],
    grad_output: Tuple[torch.Tensor],
) -> None:
    """Full backward hook that records ``hook_id`` and checks its arguments.

    Intended to be partially applied over ``self``, ``fired_hooks``,
    ``expected_module`` and ``hook_id``. When invoked with
    ``(module, grad_input, grad_output)`` it:

    * appends ``hook_id`` to ``fired_hooks`` (before any assertion), then
    * asserts ``module`` is ``expected_module`` by identity, and
    * asserts both gradient tuples hold exactly one entry.
    """
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    n_grad_in, n_grad_out = len(grad_input), len(grad_output)
    self.assertEqual(n_grad_in, 1)
    self.assertEqual(n_grad_out, 1)
def full_backward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    grad_input: Tuple[torch.Tensor],
) -> None:
    """Full backward pre-hook that records ``hook_id`` and checks its arguments.

    Partially apply the first four parameters. The hook is then invoked as
    ``(module, grads)``.

    NOTE: per the PyTorch ``register_full_backward_pre_hook`` docs, the
    gradient tuple handed to a backward *pre*-hook is the module's
    ``grad_output``. The parameter keeps the name ``grad_input`` only for
    interface compatibility.

    Appends ``hook_id`` to ``fired_hooks`` first. It then asserts that
    ``module`` is ``expected_module`` by identity and that the gradient tuple
    has exactly one entry.
    """
    fired_hooks.append(hook_id)
    is_expected = id(module)
    self.assertEqual(is_expected, id(expected_module))
    grad_count = len(grad_input)
    self.assertEqual(grad_count, 1)
class TestModuleHooks(TestCase):
    """Ordering tests for ``nn.Module`` hook registration.

    Each test registers several hooks, some with ``prepend=True``. Every hook
    appends its id to a shared list. The tests then check the firing order over
    a forward pass, a backward pass, and a second forward+backward round.
    """

    def _assert_firing(self, model, x, fired_hooks, after_forward, after_backward):
        """Drive ``model`` on ``x`` and check ``fired_hooks`` at each stage.

        Expectations, in order:

        * ``fired_hooks`` starts empty;
        * it equals ``after_forward`` after the first forward pass;
        * it equals ``after_backward`` after the first backward pass;
        * it equals ``after_backward`` repeated twice after another full
          forward+backward round.
        """
        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, after_forward)
        out.sum().backward()
        self.assertEqual(fired_hooks, after_backward)
        model(x).sum().backward()
        self.assertEqual(fired_hooks, after_backward + after_backward)

    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
    def test_forward_hooks(self):
        fired_hooks: List[int] = []
        model = ToyModel()
        x = torch.randn(10, 10)
        target = model.net1.seq2
        hook = partial(forward_hook, self, fired_hooks, target)
        # (hook_id, prepend) pairs, registered in this order.
        schedule = ((0, False), (1, True), (2, False), (3, False), (4, True))
        for hook_id, prepend in schedule:
            target.register_forward_hook(partial(hook, hook_id), prepend=prepend)
        order = [4, 1, 0, 2, 3]
        # Forward hooks fire during forward; backward adds nothing.
        self._assert_firing(model, x, fired_hooks, order, order)

    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
    def test_forward_pre_hooks(self):
        fired_hooks: List[int] = []
        model = ToyModel()
        x = torch.randn(10, 10)
        target = model.net2.seq1
        hook = partial(forward_pre_hook, self, fired_hooks, target)
        schedule = ((0, True), (1, False), (2, False), (3, False), (4, True))
        for hook_id, prepend in schedule:
            target.register_forward_pre_hook(partial(hook, hook_id), prepend=prepend)
        order = [4, 0, 1, 2, 3]
        self._assert_firing(model, x, fired_hooks, order, order)

    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
    def test_full_backward_hooks(self):
        fired_hooks: List[int] = []
        model = ToyModel()
        x = torch.randn(10, 10)
        target = model.net1
        hook = partial(full_backward_hook, self, fired_hooks, target)
        schedule = ((0, False), (1, False), (2, False), (3, True), (4, True))
        for hook_id, prepend in schedule:
            target.register_full_backward_hook(partial(hook, hook_id), prepend=prepend)
        # Nothing fires on the forward pass; the backward pass fires them all.
        self._assert_firing(model, x, fired_hooks, [], [4, 3, 0, 1, 2])

    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
    def test_full_backward_pre_hooks(self):
        fired_hooks: List[int] = []
        model = ToyModel()
        x = torch.randn(10, 10)
        target = model.net1
        hook = partial(full_backward_pre_hook, self, fired_hooks, target)
        schedule = ((0, True), (1, True), (2, False), (3, False), (4, False))
        for hook_id, prepend in schedule:
            target.register_full_backward_pre_hook(
                partial(hook, hook_id), prepend=prepend
            )
        self._assert_firing(model, x, fired_hooks, [], [1, 0, 2, 3, 4])

    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
    def test_mixed_hooks(self):
        fired_hooks: List[int] = []
        model = ToyModel()
        x = torch.randn(10, 10)
        # One hook of each kind on the root module; the hook id is the
        # position in this tuple.
        registrations = (
            (model.register_forward_pre_hook, forward_pre_hook),
            (model.register_forward_hook, forward_hook),
            (model.register_full_backward_pre_hook, full_backward_pre_hook),
            (model.register_full_backward_hook, full_backward_hook),
        )
        for hook_id, (register, hook_fn) in enumerate(registrations):
            register(partial(hook_fn, self, fired_hooks, model, hook_id))
        self._assert_firing(model, x, fired_hooks, [0, 1], [0, 1, 2, 3])
if __name__ == "__main__":
    # Entry point when this file is executed directly. It hands off to
    # run_tests(), imported from torch.testing._internal.common_utils.
    run_tests()