# blob: dc4bead782429b8cceca8c5f9b87596a521ff093 [file] [log] [blame]
# Owner(s): ["module: nn"]
import gc
import math
import pickle
import unittest
import warnings
import weakref
from collections import namedtuple, OrderedDict
from copy import deepcopy
from functools import partial
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List, Tuple
import torch
import torch.nn as nn
from torch.testing._internal.common_nn import _create_basic_net, NNTestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_WINDOWS,
parametrize as parametrize_test,
run_tests,
skipIfTorchDynamo,
swap,
TestCase,
)
class Net(nn.Module):
    """Small feed-forward network made of two sequential stacks.

    Each stack (``seq1`` and ``seq2``) holds two ``nn.Linear(10, 10)`` layers;
    ``forward`` runs the input through ``seq1`` and then ``seq2``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in the same order as before: seq1 first, then seq2.
        self.seq1 = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        self.seq2 = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``seq2(seq1(x))``."""
        hidden = self.seq1(x)
        return self.seq2(hidden)
# Single-field named tuple; ToyModel wraps its result in it when asked to.
ToyNamedTuple = namedtuple("ToyNamedTuple", ["content"])
class ToyModel(nn.Module):
    """Two ``Net`` instances chained together, returning a 1-element container.

    Args:
        with_named_tuple: when truthy, ``forward`` wraps its result in
            ``ToyNamedTuple``; otherwise it returns a plain 1-tuple.
    """

    def __init__(self, with_named_tuple=False) -> None:
        super().__init__()
        self.with_named_tuple = with_named_tuple
        # net1 is built before net2, as in the original construction order.
        self.net1 = Net()
        self.net2 = Net()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``net2(net1(x))`` and wrap the result.

        NOTE: the return annotation says ``Tensor`` but the value is always a
        1-element tuple (plain or ``ToyNamedTuple``).
        """
        res = self.net2(self.net1(x))
        return ToyNamedTuple(res) if self.with_named_tuple else (res,)
def forward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    inp: Tuple[torch.Tensor],
    out: torch.Tensor,
) -> None:
    """Forward-hook body shared by the tests.

    Records ``hook_id`` in ``fired_hooks``, then asserts (through the test
    case passed as ``self``) that the hook ran on ``expected_module`` and
    received exactly one positional input. ``out`` is not inspected.
    """
    fired_hooks.append(hook_id)
    actual_id, wanted_id = id(module), id(expected_module)
    self.assertEqual(actual_id, wanted_id)
    input_count = len(inp)
    self.assertEqual(input_count, 1)
def forward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    inp: Tuple[torch.Tensor],
) -> None:
    """Forward-pre-hook body shared by the tests.

    Records ``hook_id`` in ``fired_hooks``, then asserts that the hook ran on
    ``expected_module`` and that exactly one positional input was passed.
    """
    fired_hooks.append(hook_id)
    actual_id, wanted_id = id(module), id(expected_module)
    self.assertEqual(actual_id, wanted_id)
    input_count = len(inp)
    self.assertEqual(input_count, 1)
def full_backward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    grad_input: Tuple[torch.Tensor],
    grad_output: Tuple[torch.Tensor],
) -> None:
    """Full-backward-hook body shared by the tests.

    Records ``hook_id`` in ``fired_hooks``, then asserts that the hook ran on
    ``expected_module`` and that both gradient tuples hold one entry each.
    """
    fired_hooks.append(hook_id)
    actual_id, wanted_id = id(module), id(expected_module)
    self.assertEqual(actual_id, wanted_id)
    # grad_input is checked before grad_output, as in the original order.
    for grads in (grad_input, grad_output):
        self.assertEqual(len(grads), 1)
def full_backward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    grad_input: Tuple[torch.Tensor],
) -> None:
    """Full-backward-pre-hook body shared by the tests.

    Records ``hook_id`` in ``fired_hooks``, then asserts that the hook ran on
    ``expected_module`` and that the gradient tuple it received has exactly
    one entry.
    """
    fired_hooks.append(hook_id)
    actual_id, wanted_id = id(module), id(expected_module)
    self.assertEqual(actual_id, wanted_id)
    grad_count = len(grad_input)
    self.assertEqual(grad_count, 1)
class KwargModel(nn.Module):
    """Module whose ``forward`` takes an optional ``bias`` keyword argument.

    ``net1``/``net2`` are constructed but not used by ``forward``; the output
    is simply ``x`` or ``x + bias``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.net1 = Net()
        self.net2 = Net()

    def forward(self, x: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
        """Return ``x`` unchanged when ``bias`` is None, else ``x + bias``."""
        if bias is None:
            return x
        return x + bias

    def internal_forward_hook(
        self,
        module: nn.Module,
        args: Tuple[torch.Tensor],
        kwargs: Dict[str, Any],
        out: torch.Tensor,
    ):
        """Kwargs-aware forward hook: add ``kwargs["bias"]`` to the output."""
        extra = kwargs["bias"]
        return out + extra
class FailsInForwardModel(nn.Module):
    """Module whose ``forward`` raises unless told not to."""

    def __init__(self) -> None:
        super().__init__()
        self.net1 = Net()

    def forward(self, x: torch.Tensor, fail: bool = True) -> torch.Tensor:
        """Return ``net1(x)`` when ``fail`` is falsy; otherwise raise.

        Raises:
            RuntimeError: with message "failing in forward" when ``fail`` is
                truthy (the default).
        """
        if not fail:
            return self.net1(x)
        raise RuntimeError("failing in forward")
def kwarg_forward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    args: Tuple[torch.Tensor],
    kwargs: Dict[str, Any],
) -> Tuple[Any, Any]:
    """Kwargs-aware forward-pre-hook body shared by the tests.

    Records ``hook_id``, asserts the hook ran on ``expected_module`` with one
    positional argument, then replaces ``kwargs["bias"]`` with twice its value
    (the dict is updated in place; a new value object is stored).

    Returns:
        The ``(args, kwargs)`` pair to use for the forward call.
    """
    fired_hooks.append(hook_id)
    actual_id, wanted_id = id(module), id(expected_module)
    self.assertEqual(actual_id, wanted_id)
    positional_count = len(args)
    self.assertEqual(positional_count, 1)
    doubled_bias = 2 * kwargs["bias"]
    kwargs["bias"] = doubled_bias
    return args, kwargs
def kwarg_forward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    args: Tuple[torch.Tensor],
    kwargs: Dict[str, Any],
    out: torch.Tensor,
) -> Any:
    """Kwargs-aware forward-hook body shared by the tests.

    Records ``hook_id``, asserts the hook ran on ``expected_module`` with one
    positional argument, and returns ``out + kwargs["bias"]`` as the new output.
    """
    fired_hooks.append(hook_id)
    actual_id, wanted_id = id(module), id(expected_module)
    self.assertEqual(actual_id, wanted_id)
    positional_count = len(args)
    self.assertEqual(positional_count, 1)
    return out + kwargs["bias"]
class DummyContextManager:
    """Minimal context manager that logs enter/exit events into a list.

    Entering appends ``2`` and exiting appends ``-1`` to the list supplied at
    construction, so callers can check that every enter was matched by an exit.

    Args:
        inp: caller-owned list that receives the event markers.
    """

    def __init__(self, inp):
        self.input = inp

    def __enter__(self, *args, **kwargs):
        """Record the enter event and return the manager itself."""
        self.input.append(2)
        # Returning self makes ``with DummyContextManager(...) as ctx`` bind
        # the manager instead of None; direct ``__enter__()`` callers that
        # ignore the return value are unaffected.
        return self

    def __exit__(self, *args, **kwargs):
        """Record the exit event; returns None so exceptions propagate."""
        # Accepts any arguments because the tests also call ``__exit__()``
        # directly with none.
        self.input.append(-1)
class TestModuleHooks(TestCase):
@parametrize_test("named_tuple", (True, False))
def test_forward_hooks(self, named_tuple):
    """Forward hooks fire on forward only, with prepended hooks running first."""
    fired_hooks: List[int] = []
    model = ToyModel(named_tuple)
    x = torch.randn(10, 10)
    target = model.net1.seq2
    hook = partial(forward_hook, self, fired_hooks, target)
    # (hook id, prepend) pairs, listed in registration order.
    registrations = ((0, False), (1, True), (2, False), (3, False), (4, True))
    for hook_id, prepend in registrations:
        target.register_forward_hook(partial(hook, hook_id), prepend=prepend)
    expected = [4, 1, 0, 2, 3]

    self.assertEqual(fired_hooks, [])
    out = model(x)
    self.assertEqual(fired_hooks, expected)
    self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
    # Backward must not trigger forward hooks.
    out[0].sum().backward()
    self.assertEqual(fired_hooks, expected)
    # A second forward/backward pass fires the same sequence once more.
    model(x)[0].sum().backward()
    self.assertEqual(fired_hooks, expected + expected)
@parametrize_test("named_tuple", (True, False))
def test_forward_pre_hooks(self, named_tuple):
    """Forward pre-hooks fire on forward only, with prepended hooks running first."""
    fired_hooks: List[int] = []
    model = ToyModel(named_tuple)
    x = torch.randn(10, 10)
    target = model.net2.seq1
    hook = partial(forward_pre_hook, self, fired_hooks, target)
    # (hook id, prepend) pairs, listed in registration order.
    registrations = ((0, True), (1, False), (2, False), (3, False), (4, True))
    for hook_id, prepend in registrations:
        target.register_forward_pre_hook(partial(hook, hook_id), prepend=prepend)
    expected = [4, 0, 1, 2, 3]

    self.assertEqual(fired_hooks, [])
    out = model(x)
    self.assertEqual(fired_hooks, expected)
    self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
    # Backward must not trigger forward pre-hooks.
    out[0].sum().backward()
    self.assertEqual(fired_hooks, expected)
    # A second forward/backward pass fires the same sequence once more.
    model(x)[0].sum().backward()
    self.assertEqual(fired_hooks, expected + expected)
@parametrize_test("named_tuple", (True, False))
def test_full_backward_hooks(self, named_tuple):
    """Full backward hooks fire on backward only, prepended hooks first."""
    fired_hooks: List[int] = []
    model = ToyModel(named_tuple)
    x = torch.randn(10, 10)
    target = model.net1
    hook = partial(full_backward_hook, self, fired_hooks, target)
    # (hook id, prepend) pairs, listed in registration order.
    registrations = ((0, False), (1, False), (2, False), (3, True), (4, True))
    for hook_id, prepend in registrations:
        target.register_full_backward_hook(partial(hook, hook_id), prepend=prepend)
    expected = [4, 3, 0, 1, 2]

    self.assertEqual(fired_hooks, [])
    out = model(x)
    # Forward alone must not trigger backward hooks.
    self.assertEqual(fired_hooks, [])
    self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
    out[0].sum().backward()
    self.assertEqual(fired_hooks, expected)
    # A second forward/backward pass fires the same sequence once more.
    model(x)[0].sum().backward()
    self.assertEqual(fired_hooks, expected + expected)
@parametrize_test("named_tuple", (True, False))
def test_full_backward_pre_hooks(self, named_tuple):
    """Full backward pre-hooks fire on backward only and may rewrite grad_output."""
    fired_hooks: List[int] = []
    model = ToyModel(named_tuple)
    x = torch.randn(10, 10)
    target = model.net1
    hook = partial(full_backward_pre_hook, self, fired_hooks, target)
    # (hook id, prepend) pairs, listed in registration order.
    registrations = ((0, True), (1, True), (2, False), (3, False), (4, False))
    for hook_id, prepend in registrations:
        target.register_full_backward_pre_hook(
            partial(hook, hook_id), prepend=prepend
        )
    expected = [1, 0, 2, 3, 4]

    self.assertEqual(fired_hooks, [])
    out = model(x)
    # Forward alone must not trigger backward pre-hooks.
    self.assertEqual(fired_hooks, [])
    self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
    out[0].sum().backward()
    self.assertEqual(fired_hooks, expected)
    model(x)[0].sum().backward()
    self.assertEqual(fired_hooks, expected + expected)

    # Backward pre hook can affect subsequent gradient computation
    def zero_grad_output(_unused_module, grad_output):
        return (grad_output[0] * 0,)

    for rg in (True, False):
        a = torch.ones(2, requires_grad=rg)
        linear = nn.Linear(2, 2)
        linear.register_full_backward_pre_hook(zero_grad_output)
        result = linear(a)
        result.sum().backward()
        # The zeroed grad_output yields zero gradients downstream.
        self.assertEqual(linear.weight.grad, torch.zeros(2, 2))
        if not rg:
            self.assertIsNone(a.grad)
        else:
            self.assertEqual(a.grad, torch.zeros_like(a))
@parametrize_test("named_tuple", (True, False))
def test_mixed_hooks(self, named_tuple):
    """One hook of each kind on the same module fires in pre/forward/backward order."""
    fired_hooks: List[int] = []
    model = ToyModel(named_tuple)
    x = torch.randn(10, 10)

    def bound(hook_fn, hook_id):
        # Pre-bind the test case, the shared log, the expected module and the id.
        return partial(hook_fn, self, fired_hooks, model, hook_id)

    model.register_forward_pre_hook(bound(forward_pre_hook, 0))
    model.register_forward_hook(bound(forward_hook, 1))
    model.register_full_backward_pre_hook(bound(full_backward_pre_hook, 2))
    model.register_full_backward_hook(bound(full_backward_hook, 3))

    one_pass = [0, 1, 2, 3]
    self.assertEqual(fired_hooks, [])
    out = model(x)
    self.assertEqual(fired_hooks, one_pass[:2])
    self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
    out[0].sum().backward()
    self.assertEqual(fired_hooks, one_pass)
    model(x)[0].sum().backward()
    self.assertEqual(fired_hooks, one_pass + one_pass)
def test_kwarg_hooks(self):
    """Hooks registered with ``with_kwargs=True`` see and can rewrite kwargs."""
    # 1. test forward pre hook
    fired_hooks: List[int] = []
    x: torch.Tensor = torch.ones(10, 10)
    bias: torch.Tensor = torch.ones(10, 10)
    model = KwargModel()
    pre_hook = partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0)
    model.register_forward_pre_hook(pre_hook, with_kwargs=True)

    # forward-pre: bias' = bias * 2
    # So, out = x + bias * 2
    self.assertEqual(fired_hooks, [])
    result = model(x, bias=bias)
    self.assertEqual(fired_hooks, [0])
    self.assertEqual(result, x + 2 * bias, rtol=0, atol=1e-5)

    # 2. test forward pre and forward hooks
    fired_hooks: List[int] = []
    x: torch.Tensor = torch.ones(10, 10)
    bias: torch.Tensor = torch.ones(10, 10)
    model = KwargModel()
    post_hook = partial(kwarg_forward_hook, self, fired_hooks, model, 1)
    pre_hook = partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0)
    # The forward hook is registered before the pre-hook, as in the original.
    model.register_forward_hook(post_hook, with_kwargs=True)
    model.register_forward_pre_hook(pre_hook, with_kwargs=True)

    # forward-pre: bias' = bias * 2
    # forward: out = x + bias'
    # forward-post: out = out + bias'
    # So, out = x + bias * 4
    self.assertEqual(fired_hooks, [])
    result = model(x, bias=bias)
    self.assertEqual(fired_hooks, [0, 1])
    self.assertEqual(result, x + 4 * bias, rtol=0, atol=1e-5)

    # 3. test nn.Module member method as forward-post hook
    x: torch.Tensor = torch.ones(10, 10)
    bias: torch.Tensor = torch.ones(10, 10)
    model = KwargModel()
    model.register_forward_hook(model.internal_forward_hook, with_kwargs=True)

    # forward: out = x + bias
    # forward-post: out = out + bias
    # So, out = x + bias * 2
    result = model(x, bias=bias)
    self.assertEqual(result, x + 2 * bias, rtol=0, atol=1e-5)
def test_remove_kwarg_hooks(self):
    """Removing kwargs-aware hooks stops them firing and clears their bookkeeping."""
    # test forward pre and forward hooks
    fired_hooks: List[int] = []
    x: torch.Tensor = torch.ones(10, 10)
    bias: torch.Tensor = torch.ones(10, 10)
    model = KwargModel()
    post_hook = partial(kwarg_forward_hook, self, fired_hooks, model, 1)
    pre_hook = partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0)
    forward_hook_handle = model.register_forward_hook(post_hook, with_kwargs=True)
    forward_pre_hook_handle = model.register_forward_pre_hook(
        pre_hook, with_kwargs=True
    )

    # forward-pre: bias' = bias * 2
    # forward: out = x + bias'
    # forward-post: out = out + bias'
    # So, out = x + bias * 4
    self.assertEqual(fired_hooks, [])
    result = model(x, bias=bias)
    self.assertEqual(fired_hooks, [0, 1])
    self.assertEqual(result, x + 4 * bias, rtol=0, atol=1e-5)

    # forward-pre: bias' = bias * 2
    # forward: out = x + bias'
    # So, out = x + bias * 2
    forward_hook_handle.remove()
    result = model(x, bias=bias)
    self.assertEqual(fired_hooks, [0, 1, 0])
    self.assertEqual(result, x + 2 * bias, rtol=0, atol=1e-5)
    self.assertNotIn(forward_hook_handle.id, model._forward_hooks_with_kwargs)

    # forward: out = x + bias
    # So, out = x + bias
    forward_pre_hook_handle.remove()
    result = model(x, bias=bias)
    # No hook is left, so the log does not grow.
    self.assertEqual(fired_hooks, [0, 1, 0])
    self.assertEqual(result, x + bias, rtol=0, atol=1e-5)
    self.assertNotIn(
        forward_pre_hook_handle.id, model._forward_pre_hooks_with_kwargs
    )
def test_always_called_forward_hooks(self):
    """Forward hooks registered with ``always_call=True`` run even on failure.

    A pre-hook enters a ``DummyContextManager`` (appending 2 to ``stack``) and
    an always-call forward hook exits it (appending -1). Whatever fails in
    between, every enter must be matched by exactly one exit.
    """
    x: torch.Tensor = torch.ones(10, 10)
    model = FailsInForwardModel()
    stack = []
    ctx = None
    # One matched enter/exit pair as recorded by DummyContextManager.
    pair = [2, -1]

    def setup_context():
        nonlocal ctx
        ctx = DummyContextManager(stack)

    def ctx_setup_hook(m, i):
        setup_context()
        ctx.__enter__()

    def ctx_setup_failure_hook(m, i):
        setup_context()
        ctx.__enter__()
        raise RuntimeError("failing in ctx setup")

    def ctx_shutdown_hook(m, i, o):
        ctx.__exit__()

    def ctx_shutdown_failure_hook(m, i, o):
        ctx.__exit__()
        raise RuntimeError("failing in ctx shutdown")

    def throw_hook(m, i, o):
        raise RuntimeError("failing in throw")

    forward_pre_hook_handle = model.register_forward_pre_hook(ctx_setup_hook)
    forward_hook_handle = model.register_forward_hook(
        ctx_shutdown_hook, always_call=True
    )
    self.assertEqual(len(model._forward_hooks_always_called), 1)

    # make sure always_called forward hook runs when model.forward raises RuntimeError
    with self.assertRaisesRegex(RuntimeError, "failing in forward"):
        model(x)
    self.assertEqual(stack, pair)

    # make sure that always_called forward hook does not run twice if there is no error
    model(x, fail=False)
    self.assertEqual(stack, pair * 2)

    # make sure always_called forward hook runs when forward pre hook raises RuntimeError
    forward_pre_hook_handle.remove()
    model.register_forward_pre_hook(ctx_setup_failure_hook)
    with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
        model(x, fail=False)
    self.assertEqual(stack, pair * 3)

    # make sure always_called hook runs when another always_called forward hook raises an error
    forward_hook_handle2 = model.register_forward_hook(
        throw_hook, prepend=True, always_call=True
    )
    # error raised should not be error of the forced hook
    with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
        model(x, fail=False)
    self.assertEqual(stack, pair * 4)

    # make sure that always called forward hooks are properly removed
    forward_hook_handle.remove()
    forward_hook_handle2.remove()
    self.assertEqual(len(model._forward_hooks_always_called), 0)

    # make sure that always called forward hook is not run twice if it fails while running
    forward_hook_handle3 = model.register_forward_hook(
        ctx_shutdown_failure_hook, always_call=True
    )
    with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
        model(x, fail=False)
    self.assertEqual(stack, pair * 5)
    forward_hook_handle3.remove()

    global_forward_hook_handle = nn.modules.module.register_module_forward_hook(
        ctx_shutdown_hook, always_call=True
    )
    # The global hook is process-wide state: remove it even if an assertion
    # below fails, otherwise it would keep firing in later tests.
    try:
        self.assertEqual(
            len(nn.modules.module._global_forward_hooks_always_called), 1
        )
        # make sure global forward hook runs when forward pre hook raises RuntimeError
        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x, fail=False)
        self.assertEqual(stack, pair * 6)
    finally:
        # make sure forced global forward hook is properly removed
        global_forward_hook_handle.remove()
    self.assertEqual(len(nn.modules.module._global_forward_hooks_always_called), 0)

    # With no always-call hook left, the failing pre-hook's enter is unmatched.
    with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
        model(x)
    self.assertEqual(stack, pair * 6 + [2])
def test_bw_hook_warning_for_non_tensor_or_tuple(self):
    """A dict output makes backward hooks warn and never fire."""
    # Test to verify that backward hook raises warning
    # if result is not a Tensor or tuple of Tensors.
    counter = {"forward": 0, "backward": 0}

    def bump(key):
        counter[key] += 1

    def fw_pre_hook(module: nn.Module, _inputs):
        bump("forward")

    def fw_hook(module: nn.Module, _inputs, _outputs):
        bump("forward")

    def bw_hook(module: nn.Module, _inputs, _outputs):
        bump("backward")

    class TestModule(nn.Module):
        # Takes and returns a dict, i.e. neither a Tensor nor a tuple.
        def forward(self, payload):
            probs = torch.nn.functional.softmax(payload["x"], dim=0)
            return {"x": probs}

    x = torch.ones(2, requires_grad=True)
    model = TestModule()
    model.register_forward_pre_hook(fw_pre_hook)
    model.register_forward_hook(fw_hook)
    model.register_full_backward_pre_hook(bw_hook)
    model.register_full_backward_hook(bw_hook)

    with warnings.catch_warnings(record=True) as caught:
        y = model({"x": x})["x"]
        loss = y.sum()
    loss.backward()

    self.assertEqual(counter["forward"], 2)
    self.assertEqual(counter["backward"], 0)
    self.assertEqual(len(caught), 1)
    self.assertIn(
        "should be a Tensor or a tuple of Tensors", str(caught[0].message)
    )
def _hook_to_pickle(*args, **kwargs):
pass
class TestStateDictHooks(TestCase):
@swap([True, False])
def test_load_state_dict_pre_hook(self):
    """Pre-hooks registered with and without the module argument both run on load."""
    m = nn.Linear(10, 10)
    m_state_dict = m.state_dict()
    m_load = nn.Linear(10, 10)
    hook_called = 0

    def hook_without_module(
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        nonlocal hook_called
        self.assertEqual(m_state_dict, state_dict)
        hook_called += 1

    def hook_with_module(
        module,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        nonlocal hook_called
        self.assertEqual(m_state_dict, state_dict)
        self.assertIs(m_load, module)
        hook_called += 1

    # Only the module-less hook is registered: one call per load.
    m_load._register_load_state_dict_pre_hook(hook_without_module)
    m_load.load_state_dict(m_state_dict)
    self.assertEqual(1, hook_called)

    # Both hooks are now registered: two calls per load.
    hook_called = 0
    m_load._register_load_state_dict_pre_hook(hook_with_module, True)
    m_load.load_state_dict(m_state_dict)
    self.assertEqual(2, hook_called)
def test_no_extra_ref_to_module(self):
    """Registering a pre-hook with the module argument must not keep the module alive.

    The cyclic garbage collector is switched off so that the module can only
    go away through reference counting; the weak reference must be dead right
    after ``del m``.
    """
    # Remember the collector state so it is restored, not forced on.
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        m = nn.Linear(10, 10)
        m._register_load_state_dict_pre_hook(_hook_to_pickle, True)
        weak_m = weakref.ref(m)
        del m
        self.assertIsNone(weak_m())
    finally:
        if gc_was_enabled:
            gc.enable()
def test_pickled_hook(self):
    """A module carrying a load-state-dict pre-hook survives a pickle round-trip."""
    module = nn.Linear(10, 10)
    module._register_load_state_dict_pre_hook(_hook_to_pickle, True)
    payload = pickle.dumps(module)
    # The test passes as long as neither direction raises.
    pickle.loads(payload)
@swap([True, False])
def test_load_state_dict_module_pre_hook(self):
    """Bound methods work as load pre-hooks, directly and on a wrapped submodule."""
    hook_called = 0

    # Test with module instance method as hook
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.foo = torch.nn.Parameter(torch.rand(10))

        def my_pre_load_hook(
            self,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            nonlocal hook_called
            assert [] == error_msgs
            assert [] == unexpected_keys
            assert [] == missing_keys
            assert strict
            hook_called += 1

        def my_pre_load_hook_with_module(
            self,
            module,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            nonlocal hook_called
            assert [] == error_msgs
            assert [] == unexpected_keys
            assert [] == missing_keys
            assert strict
            assert self is module
            hook_called += 1

    # Test that hooks registered on a submodule are also called
    # appropriately, i.e. with the submodule as module argument in
    # my_pre_load_hook_with_module.
    class MyModuleContainer(nn.Module):
        def __init__(self, mod):
            super().__init__()
            self.mod = mod

    # First wrapped in a container, then bare (identity constructor).
    for ctor in (MyModuleContainer, lambda x: x):
        m = ctor(MyModule())
        state_dict = m.state_dict()
        mod = m.mod if isinstance(m, MyModuleContainer) else m

        hook_called = 0
        mod._register_load_state_dict_pre_hook(mod.my_pre_load_hook)
        m.load_state_dict(state_dict)
        self.assertEqual(1, hook_called)

        # With the second hook added, each load triggers both.
        hook_called = 0
        mod._register_load_state_dict_pre_hook(
            mod.my_pre_load_hook_with_module, True
        )
        m.load_state_dict(state_dict)
        self.assertEqual(2, hook_called)
@swap([True, False])
def test_load_state_dict_post_hook(self):
    """Check load_state_dict post-hooks: invocation, key mutation, and removal.

    A hook registered on a nested module must run when the wrapping module
    loads a state dict, must be able to add to or clear the reported
    missing/unexpected keys, and must stop running once its handle is removed.
    """
    # Number of times my_post_load_hook has run.
    hook_called = 0

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.foo = torch.nn.Parameter(torch.rand(10))

        def my_post_load_hook(self, module, incompatible_keys):
            # Registered as a bound method, so the hook's module argument
            # must be this very instance.
            assert module is self
            nonlocal hook_called
            # Inject fake incompatible keys so the caller can observe them.
            incompatible_keys.missing_keys.append("foo")
            incompatible_keys.unexpected_keys.append("bar")
            hook_called += 1

    nested = MyModule()
    wrapped = nn.ModuleList([nested])
    handle = nested.register_load_state_dict_post_hook(
        nested.my_post_load_hook,
    )
    # Hook must be called even if it is wrapped
    ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
    self.assertEqual(hook_called, 1)
    # Ensure that the hook modified missing_keys and unexpected_keys
    missing = ret.missing_keys
    unexpected = ret.unexpected_keys
    self.assertEqual(missing, ["foo"])
    self.assertEqual(unexpected, ["bar"])
    # When called with strict=True, the error raised should mention the
    # missing and unexpected keys the hook added.
    with self.assertRaisesRegex(RuntimeError, "foo.*\n.*bar"):
        wrapped.load_state_dict(wrapped.state_dict(), strict=True)
    self.assertEqual(hook_called, 2)
    # Removing the hook via handle.remove() should cause it not to
    # fire anymore.
    handle.remove()
    # Hook did not run so it should not have added any keys
    ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
    self.assertEqual(ret.missing_keys, [])
    self.assertEqual(ret.unexpected_keys, [])
    # hook_called should not have been incremented
    self.assertEqual(hook_called, 2)

    def load_hook_clear_incompatible(module, incompatible_keys):
        # Wipe both key lists so a strict load sees nothing to complain about.
        incompatible_keys.missing_keys.clear()
        incompatible_keys.unexpected_keys.clear()

    nested.register_load_state_dict_post_hook(load_hook_clear_incompatible)
    state_dict = wrapped.state_dict()
    # Add a key the model does not have, which would otherwise be "unexpected".
    state_dict["extra"] = torch.ones(1)
    # load state_dict with strict=True should not throw.
    ret = wrapped.load_state_dict(state_dict, strict=True)
    # explicitly ensure that the post hook cleaned out incompatible_keys
    self.assertEqual([], ret.missing_keys)
    self.assertEqual([], ret.unexpected_keys)
@unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
@swap([True, False])
def test_load_state_dict_post_hook_backward_compatibility(self):
    """Modules pickled without ``_load_state_dict_post_hooks`` still load and hook.

    Simulates a module saved by an older release by deleting the attribute
    before a ``torch.save``/``torch.load`` round-trip, then checks that
    ``load_state_dict`` works and post-hooks can still be registered.
    """

    def my_post_load_hook(mod, _):
        nonlocal called
        called = True

    for m in [nn.Softmin(10), nn.Softmax(10), nn.LogSoftmax(10)]:
        called = False
        sd = deepcopy(m.state_dict())
        self.assertTrue(hasattr(m, "_load_state_dict_post_hooks"))
        # Simulate an older model that did not have this attr
        delattr(m, "_load_state_dict_post_hooks")
        # Save and load, and ensure that load_state_dict works (without proper
        # BC we would run into errors because this attribute would be expected).
        # In particular, Softmax runs into the issue described here:
        # https://github.com/pytorch/pytorch/issues/77280
        with NamedTemporaryFile() as f:
            # Note that torch.save / torch.load is not recommended to save/load
            # modules.
            torch.save(m, f.name)
            # A whole pickled module cannot be read in weights-only mode, so
            # request full unpickling explicitly instead of relying on the
            # default. The file was written by this test and is trusted.
            m = torch.load(f.name, weights_only=False)
            m.load_state_dict(sd)
            self.assertFalse(called)

        # Ensure hooks can be registered and called.
        m.register_load_state_dict_post_hook(my_post_load_hook)
        m.load_state_dict(sd)
        self.assertTrue(called)
def _test_register_state_dict_pre_hook(self, model, submodule):
    """Shared driver: state_dict pre-hooks run on a model and one of its submodules.

    Registers the same counting hook on ``model`` and ``submodule`` and checks
    that each ``state_dict()`` call triggers it exactly twice, both before and
    after a forward pass.

    Args:
        model: module under test; must accept a ``(10, 3)`` tensor in forward.
        submodule: a child of ``model`` that also gets the hook.
    """
    _state_dict_prefix = "foo."
    state_dict_pre_hook_count = 0
    # The keep_vars value the hook expects; rebound by check_results' loop.
    keep_var_setting = False

    def my_state_dict_pre_hook(module, prefix, keep_vars):
        self.assertEqual(keep_vars, keep_var_setting)
        nonlocal state_dict_pre_hook_count
        state_dict_pre_hook_count += 1
        # Submodules see the prefix extended with their own name, hence startswith.
        self.assertTrue(prefix.startswith(_state_dict_prefix))

    model.register_state_dict_pre_hook(my_state_dict_pre_hook)
    # Test to ensure submodules run the hook as well.
    submodule.register_state_dict_pre_hook(my_state_dict_pre_hook)

    def check_results(model):
        nonlocal state_dict_pre_hook_count, keep_var_setting
        # The loop variable is the nonlocal itself, so the hook sees the
        # current value on every iteration.
        for keep_var_setting in [True, False]:
            _ = model.state_dict(
                prefix=_state_dict_prefix, keep_vars=keep_var_setting
            )
            # One call for the model plus one for the submodule.
            self.assertEqual(2, state_dict_pre_hook_count)
            state_dict_pre_hook_count = 0

    # Test state dict works as expected after model construction
    check_results(model)
    # Test state dict works as expected after forward
    model(torch.ones(10, 3))
    check_results(model)
def test_register_state_dict_pre_hook(self):
    """Run the shared state_dict pre-hook check on a plain three-layer model."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            layers = [nn.Linear(3, 3) for _ in range(3)]
            self.a = nn.Sequential(*layers)

        def forward(self, x):
            return self.a(x)

    model = MyModule()
    self._test_register_state_dict_pre_hook(model, model.a)
def test_register_state_dict_pre_hook_lazy_module(self):
    """Run the shared state_dict pre-hook check on a model built from lazy layers."""

    class MyLazyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer1 = nn.LazyLinear(8)
            self.layer2 = nn.LazyLinear(5)

        def forward(self, x):
            hidden = self.layer1(x)
            return self.layer2(hidden)

    model = MyLazyModule()
    self._test_register_state_dict_pre_hook(model, model.layer1)
@unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
def test_register_state_dict_pre_hook_backward_compat(self):
    """Modules pickled without ``_state_dict_pre_hooks`` can still use state_dict.

    Simulates a module saved by an older release by deleting the attribute
    before a ``torch.save``/``torch.load`` round-trip, then checks that
    ``state_dict()`` works and pre-hooks can still be registered and fire.
    """
    called = False

    def my_state_dict_pre_hook(*args, **kwargs):
        nonlocal called
        called = True

    m = nn.Linear(1, 1)
    self.assertTrue(hasattr(m, "_state_dict_pre_hooks"))
    delattr(m, "_state_dict_pre_hooks")
    # Save and load, ensure we can still call state_dict
    # without running into issues.
    with NamedTemporaryFile() as f:
        # Note that torch.save / torch.load is not recommended
        # to save / load modules.
        torch.save(m, f.name)
        # A whole pickled module cannot be read in weights-only mode, so
        # request full unpickling explicitly instead of relying on the
        # default. The file was written by this test and is trusted.
        m = torch.load(f.name, weights_only=False)

    # Ensure we can run state_dict without issues
    _ = m.state_dict()
    self.assertFalse(called)
    m.register_state_dict_pre_hook(my_state_dict_pre_hook)
    _ = m.state_dict()
    self.assertTrue(called)
class TestModuleGlobalHooks(TestCase):
def tearDown(self):
    """Replace the global backward/forward/forward-pre hook registries with empty ones."""
    registry_owner = nn.modules.module
    # Same three attributes, in the same order, each rebound to a fresh dict.
    for attr_name in (
        "_global_backward_hooks",
        "_global_forward_hooks",
        "_global_forward_pre_hooks",
    ):
        setattr(registry_owner, attr_name, OrderedDict())
@skipIfTorchDynamo("TorchDynamo does not work well with hooks")
def test_module_global_hooks(self):
    """Global forward/backward hooks fire for every module and can be removed.

    Three ``nn.Sigmoid`` instances share process-wide hooks. ``counter`` is
    bumped by each hook's increment, so the expected totals below encode
    exactly which hooks were registered at each step.
    """
    # NOTE: ``module`` is the class, used for the isinstance checks in the hooks.
    module = nn.Sigmoid
    module_1 = module()
    module_2 = module()
    module_3 = module()
    input = torch.ones(5, 5, requires_grad=True)
    counter = {"forwards": 0, "backwards": 0}

    def fw_hook(inc, h_module, input, output):
        self.assertIsInstance(input, tuple)
        self.assertTrue(isinstance(output, torch.Tensor))
        self.assertTrue(isinstance(h_module, module))
        self.assertEqual(input[0], torch.ones(5, 5))
        # sigmoid(1) == 1 / (1 + 1 / e)
        self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e)))
        counter["forwards"] += inc

    def bw_hook(inc, h_module, grad_input, grad_output):
        self.assertIsInstance(grad_input, tuple)
        self.assertIsInstance(grad_output, tuple)
        self.assertTrue(isinstance(h_module, module))
        # Every backward call in this test passes ones * 2 as the gradient.
        self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
        counter["backwards"] += inc

    # One forward hook (+1): three forwards -> 3.
    test_fwd = nn.modules.module.register_module_forward_hook(
        lambda *args: fw_hook(1, *args)
    )
    module_1(input)
    module_2(input)
    module_3(input)
    self.assertEqual(counter["forwards"], 3)
    self.assertEqual(counter["backwards"], 0)
    # Add one backward hook (+1); forwards alone do not trigger it.
    test_bwd = nn.modules.module.register_module_backward_hook(
        lambda *args: bw_hook(1, *args)
    )
    output_1 = module_1(input)
    output_2 = module_2(input)
    output_3 = module_3(input)
    self.assertEqual(counter["forwards"], 6)
    self.assertEqual(counter["backwards"], 0)
    # Three backward passes -> 3; output_1 keeps its graph for a later pass.
    output_1.backward(torch.ones(5, 5) * 2, retain_graph=True)
    output_2.backward(torch.ones(5, 5) * 2, retain_graph=False)
    output_3.backward(torch.ones(5, 5) * 2, retain_graph=False)
    self.assertEqual(counter["forwards"], 6)
    self.assertEqual(counter["backwards"], 3)
    # Backward through the retained graph fires the hook again -> 4.
    output_1.backward(torch.ones(5, 5) * 2, retain_graph=True)
    self.assertEqual(counter["forwards"], 6)
    self.assertEqual(counter["backwards"], 4)
    # Second forward hook (+2): each forward now adds 3 -> 6 + 9 = 15.
    test2_fwd = nn.modules.module.register_module_forward_hook(
        lambda *args: fw_hook(2, *args)
    )
    output = module_1(input)
    output = module_2(input)
    output = module_3(input)
    self.assertEqual(counter["forwards"], 15)
    self.assertEqual(counter["backwards"], 4)
    # Second backward hook (+2): one forward (+3) and one backward (+3).
    test2_bwd = nn.modules.module.register_module_backward_hook(
        lambda *args: bw_hook(2, *args)
    )
    module_1(input).backward(torch.ones(5, 5) * 2)
    self.assertEqual(counter["forwards"], 18)
    self.assertEqual(counter["backwards"], 7)
    # Without the +2 backward hook: forward +3, backward +1.
    test2_bwd.remove()
    module_2(input).backward(torch.ones(5, 5) * 2)
    self.assertEqual(counter["forwards"], 21)
    self.assertEqual(counter["backwards"], 8)
    # Without the +2 forward hook: forward +1, backward +1.
    test2_fwd.remove()
    module_3(input).backward(torch.ones(5, 5) * 2)
    self.assertEqual(counter["forwards"], 22)
    self.assertEqual(counter["backwards"], 9)
    test_fwd.remove()
    test_bwd.remove()
def test_module_global_hook_invalid_outputs(self):
    """A global backward hook returning the wrong number of gradients raises."""
    module = nn.Sigmoid()
    input = torch.randn(5, 5, requires_grad=True)
    register_hook = nn.modules.module.register_module_backward_hook

    def bw_fail1(mod, grad_input, grad_output):
        # One gradient too few.
        return grad_input[:-1]

    def bw_fail2(mod, grad_input, grad_output):
        # One gradient too many.
        return grad_input + (torch.randn(2, 2),)

    cases = (
        (bw_fail1, "got 0, but expected 1"),
        (bw_fail2, "got 2, but expected 1"),
    )
    for bad_hook, message in cases:
        # The handle is used as a context manager, as in the original nesting.
        with register_hook(bad_hook):
            with self.assertRaisesRegex(RuntimeError, message):
                module(input).sum().backward()
def test_module_backward_global_hook_writeable(self):
    """A global backward hook's return value replaces the input gradients."""
    module = nn.Sigmoid()
    input = torch.randn(5, 5, requires_grad=True)
    sig_x = torch.sigmoid(input)

    def bw_hook(module, grad_input, grad_output):
        for grad in (*grad_input, *grad_output):
            self.assertIsInstance(grad, torch.Tensor)
        # Double every input gradient.
        return tuple([gi * 2 for gi in grad_input])

    nn.modules.module.register_module_backward_hook(bw_hook)
    upstream = torch.ones(5, 5)
    module(input).backward(upstream)
    # d/dx sigmoid(x) = s * (1 - s), doubled by the hook.
    expected_grad = sig_x * (1 - sig_x) * 2
    self.assertEqual(input.grad, expected_grad)
@skipIfTorchDynamo("TorchDynamo does not work well with hooks")
def test_module_global_forward_preforward_hook_writeable(self):
    """Global forward pre-hooks can replace the input, forward hooks the output."""
    module = nn.Sigmoid()
    input = torch.randn(5, 5, requires_grad=True)
    sig_x = torch.sigmoid(input)
    relu = torch.nn.functional.relu

    def relu_input_hook(m, input):
        return relu(input[0])

    def negate_output_hook(m, input, output):
        return -output

    nn.modules.module.register_module_forward_pre_hook(relu_input_hook)
    nn.modules.module.register_module_forward_hook(negate_output_hook)

    output = module(input)
    expected_res = -torch.sigmoid(relu(input))
    self.assertEqual(output, expected_res)

    output.backward(torch.ones(5, 5) * 2, retain_graph=True)
    # relu only lets gradient through where the input is positive.
    mask = input > 0
    expected_grad = -sig_x * (1 - sig_x) * 2 * mask
    self.assertEqual(input.grad, expected_grad)
def test_module_forward_preforward_hook_removable(self):
    """
    Several pre-forward hooks can be registered on one module and each may
    remove its own handle while it is running, without disturbing the call.
    """
    module = nn.Sigmoid()

    def removable_hook(m, input):
        nonlocal handle
        handle.remove()
        return input

    def removable_hook_2(m, input):
        nonlocal handle_2
        handle_2.remove()
        return input

    handle = module.register_forward_pre_hook(removable_hook)
    handle_2 = module.register_forward_pre_hook(removable_hook_2)

    # make sure hook register is successful
    for registered in (handle, handle_2):
        self.assertEqual(len(registered.hooks_dict_ref()), 2)

    sample = torch.randn(2, 2)
    result = module(sample)
    self.assertEqual(torch.sigmoid(sample), result)

    # make sure hook removal is successful
    self.assertNotIn(handle.id, handle.hooks_dict_ref())
    self.assertNotIn(handle_2.id, handle.hooks_dict_ref())
    for registered in (handle, handle_2):
        self.assertEqual(len(registered.hooks_dict_ref()), 0)
def test_module_forward_forward_hook_removable(self):
    """
    Several forward hooks can be registered on one module and each may remove
    its own handle while it is running, without disturbing the call.
    """
    module = nn.Sigmoid()

    def removable_hook(m, input, output):
        nonlocal handle
        handle.remove()
        return output

    def removable_hook_2(m, input, output):
        nonlocal handle_2
        handle_2.remove()
        return output

    handle = module.register_forward_hook(removable_hook)
    handle_2 = module.register_forward_hook(removable_hook_2)

    # make sure hook register is successful
    for registered in (handle, handle_2):
        self.assertEqual(len(registered.hooks_dict_ref()), 2)

    sample = torch.randn(2, 2)
    result = module(sample)
    self.assertEqual(torch.sigmoid(sample), result)

    # make sure hook removal is successful
    self.assertNotIn(handle.id, handle.hooks_dict_ref())
    self.assertNotIn(handle_2.id, handle.hooks_dict_ref())
    for registered in (handle, handle_2):
        self.assertEqual(len(registered.hooks_dict_ref()), 0)
@skipIfTorchDynamo("TorchDynamo does not work well with hooks")
def test_global_and_local_hooks_order(self):
    """Global hooks must run before the module's own hooks of the same kind.

    Each hook sets its own flag; every global hook asserts that the matching
    local flag is still unset, and every local hook asserts that the matching
    global flag is already set.
    """
    module = nn.Sigmoid()
    # One "has run" flag per hook, read across closures to check ordering.
    global_forward_pre_called = False
    local_forward_pre_called = False
    global_forward_called = False
    local_forward_called = False
    global_backward_called = False
    local_backward_called = False

    def global_forward_pre_hook(m, input):
        nonlocal global_forward_pre_called
        # The local pre-hook must not have run yet.
        self.assertTrue(not local_forward_pre_called)
        global_forward_pre_called = True
        return input

    def local_forward_pre_hook(m, input):
        nonlocal local_forward_pre_called
        # The global pre-hook must already have run.
        self.assertTrue(global_forward_pre_called)
        local_forward_pre_called = True
        return input

    def global_forward_hook(m, input, output):
        nonlocal global_forward_called
        self.assertTrue(not local_forward_called)
        global_forward_called = True
        return output

    def local_forward_hook(m, input, output):
        nonlocal local_forward_called
        self.assertTrue(global_forward_called)
        local_forward_called = True
        return output

    def global_backward_hook(m, input, output):
        nonlocal global_backward_called
        self.assertTrue(not local_backward_called)
        global_backward_called = True
        return input

    def local_backward_hook(m, input, output):
        nonlocal local_backward_called
        self.assertTrue(global_backward_called)
        local_backward_called = True
        return input

    input = torch.randn(5, 5, requires_grad=True)
    # Register each global hook alongside its local counterpart.
    nn.modules.module.register_module_forward_pre_hook(global_forward_pre_hook)
    module.register_forward_pre_hook(local_forward_pre_hook)
    nn.modules.module.register_module_forward_hook(global_forward_hook)
    module.register_forward_hook(local_forward_hook)
    nn.modules.module.register_module_backward_hook(global_backward_hook)
    module.register_backward_hook(local_backward_hook)

    output = module(input)
    # All four forward-side hooks must have fired during the call.
    self.assertTrue(
        local_forward_called
        and local_forward_pre_called
        and global_forward_called
        and global_forward_pre_called
    )

    output.backward(torch.ones(5, 5), retain_graph=True)
    # Both backward hooks must have fired during the backward pass.
    self.assertTrue(local_backward_called and global_backward_called)
class TestModuleHookNN(NNTestCase):
# Module hook tests (forward, backward and registration hooks) built on NNTestCase.
# NOTE(review): the two flags below are presumably read by the test base class
# to turn on CUDA memory-leak checking and non-default-stream execution —
# confirm against NNTestCase / TestCase before relying on this.
_do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True
def _test_hooks(self, backward_register_fn):
module = nn.Sigmoid()
input = torch.ones(5, 5, requires_grad=True)
counter = {"forwards": 0, "backwards": 0}
def fw_hook(inc, h_module, input, output):
self.assertIsInstance(input, tuple)
self.assertTrue(isinstance(output, torch.Tensor))
self.assertTrue(h_module is module)
self.assertEqual(input[0], torch.ones(5, 5))
self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e)))
counter["forwards"] += inc
def bw_hook(inc, h_module, grad_input, grad_output):
self.assertIsInstance(grad_input, tuple)
self.assertIsInstance(grad_output, tuple)
self.assertTrue(h_module is module)
self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
counter["backwards"] += inc
# backward_pre_hook expects callback with only `module` and `grad_output`
# as arguments.
def bw_pre_hook(inc, h_module, grad_output):
self.assertIsInstance(grad_output, tuple)
self.assertTrue(h_module is module)
self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
counter["backwards"] += inc
test_fwd = module.register_forward_hook(lambda *args: fw_hook(1, *args))
module(input)
module(input)
self.assertEqual(counter["forwards"], 2)
self.assertEqual(counter["backwards"], 0)
bw_hook_fn = (
bw_pre_hook
if backward_register_fn == "register_full_backward_pre_hook"
else bw_hook
)
test_bwd = getattr(module, backward_register_fn)(
lambda *args: bw_hook_fn(1, *args)
)
output = module(input)
self.assertEqual(counter["forwards"], 3)
self.assertEqual(counter["backwards"], 0)
output.backward(torch.ones(5, 5) * 2, retain_graph=True)
self.assertEqual(counter["forwards"], 3)
self.assertEqual(counter["backwards"], 1)
output.backward(torch.ones(5, 5) * 2, retain_graph=True)
self.assertEqual(counter["forwards"], 3)
self.assertEqual(counter["backwards"], 2)
test2_fwd = module.register_forward_hook(lambda *args: fw_hook(2, *args))
output = module(input)
self.assertEqual(counter["forwards"], 6)
self.assertEqual(counter["backwards"], 2)
test2_bwd = getattr(module, backward_register_fn)(
lambda *args: bw_hook_fn(2, *args)
)
module(input).backward(torch.ones(5, 5) * 2)
self.assertEqual(counter["forwards"], 9)
self.assertEqual(counter["backwards"], 5)
test2_bwd.remove()
module(input).backward(torch.ones(5, 5) * 2)
self.assertEqual(counter["forwards"], 12)
self.assertEqual(counter["backwards"], 6)
test2_fwd.remove()
module(input).backward(torch.ones(5, 5) * 2)
self.assertEqual(counter["forwards"], 13)
self.assertEqual(counter["backwards"], 7)
test_fwd.remove()
test_bwd.remove()
def test_hooks(self):
self._test_hooks("register_backward_hook")
self._test_hooks("register_full_backward_hook")
self._test_hooks("register_full_backward_pre_hook")
def test_hook_cpp(self):
bn = nn.BatchNorm1d(5)
def hook(module, grad_inputs, grad_outputs):
self.assertEqual(len(grad_inputs), 1)
self.assertEqual(len(grad_outputs), 1)
self.assertEqual(module, bn)
bn.register_full_backward_hook(hook)
output = bn(torch.randn(5, 5, requires_grad=True))
output.sum().backward()
def test_backward_hooks_interaction(self):
# Test to make sure that the grad_outputs
# updated by full_backward_pre_hook are received by
# the full_backward_hook
module = torch.nn.Sigmoid()
cnt = {"backward_cnt": 0}
def bw_pre_hook(m, grad_output):
cnt["backward_cnt"] += 1
return (grad_output[0] * 0.5,)
def bw_hook(m, grad_in, grad_output):
self.assertEqual(torch.full_like(grad_output[0], 0.5), grad_output[0])
cnt["backward_cnt"] += 1
return grad_output
module.register_full_backward_pre_hook(bw_pre_hook)
module.register_full_backward_hook(bw_hook)
t = torch.ones(1, 2, requires_grad=True)
module(t).sum().backward()
self.assertEqual(cnt["backward_cnt"], 2)
def test_hook_invalid_outputs(self):
module = nn.Sigmoid()
input = torch.randn(5, 5, requires_grad=True)
def bw_fail1(self, grad_input, grad_output):
return grad_input[:-1]
def bw_fail2(self, grad_input, grad_output):
return grad_input + (torch.randn(2, 2),)
with module.register_backward_hook(bw_fail1):
with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"):
module(input).sum().backward()
with module.register_backward_hook(bw_fail2):
with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"):
module(input).sum().backward()
def bw_pre_fail1(self, grad_output):
return ()
def bw_pre_fail2(self, grad_output):
return grad_output + (torch.randn(2, 2),)
with module.register_full_backward_pre_hook(bw_pre_fail1):
with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"):
module(input).sum().backward()
with module.register_full_backward_pre_hook(bw_pre_fail2):
with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"):
module(input).sum().backward()
def test_hook_requires_grad(self):
test_self = self
class MyModule(nn.Module):
def forward(self, arg1, arg2, arg3):
test_self.assertTrue(arg1.requires_grad)
test_self.assertFalse(arg2.requires_grad)
test_self.assertTrue(arg3.requires_grad)
return arg1.sum() + arg2.sum() + arg3.sum()
inp = torch.rand(2, requires_grad=True)
mod = MyModule()
mod(inp, inp.detach(), inp)
# Ensure that requires grad is properly propagated
mod.register_full_backward_hook(lambda mod, gI, gO: None)
mod(inp, inp.detach(), inp)
def test_hook_no_requires_grad(self):
mod = nn.Linear(2, 3)
inp = torch.rand(1, 2)
return_val = "None"
hook_called = [0]
def hook(mod, grad_input, grad_output):
hook_called[0] += 1
for gI in grad_input:
self.assertIsNone(gI)
for gO in grad_output:
self.assertEqual(gO.size(), (1, 3))
if return_val == "grad_input":
return grad_input
elif return_val == "invalid":
# If the inputs were requiring gradients, this would be
# a valid return
return inp
elif return_val == "None":
return None
else:
raise RuntimeError("Invalid return_val string")
mod.register_full_backward_hook(hook)
# This should run and trigger the hook properly
mod(inp).sum().backward()
self.assertEqual(hook_called[0], 1)
return_val = "grad_input"
mod(inp).sum().backward()
self.assertEqual(hook_called[0], 2)
return_val = "invalid"
with self.assertRaisesRegex(RuntimeError, "where no input requires gradient"):
mod(inp).sum().backward()
def test_hook_last_arg_requires_grad(self):
mod = nn.L1Loss()
inp = torch.rand(1, requires_grad=True)
mod.register_full_backward_hook(lambda m, gI, gO: None)
try:
mod(inp.detach(), inp)
except Exception as ex:
self.fail(f"Unexpected exception: {ex}")
def test_hook_extra_input(self):
class MyModule(nn.Module):
def forward(self, non_tensor, tensor):
return tensor.clone(), non_tensor
inp = torch.rand(2, requires_grad=True)
mod = MyModule()
def hook(mod, grad_input, grad_output):
self.assertIsNone(grad_input[0])
self.assertIsInstance(grad_input[1], torch.Tensor)
self.assertIsInstance(grad_output[0], torch.Tensor)
self.assertIsNone(grad_output[1])
mod.register_full_backward_hook(hook)
out, _ = mod(True, inp)
out.sum().backward()
def test_hook_inplace(self):
class MyModule(nn.Module):
def forward(self, inp, do_inplace):
self.inp = inp
if do_inplace:
inp += 1
return inp.clone()
hook_called = [0]
def hook(mod, grad_input, grad_output):
hook_called[0] += 1
def hook_pre(mod, grad_output):
hook_called[0] += 1
inp = torch.rand(10, requires_grad=True)
mod = MyModule()
for hook_fn, register_fn in [
(hook, mod.register_full_backward_hook),
(hook_pre, mod.register_full_backward_pre_hook),
]:
hook_called[0] = 0
with register_fn(hook_fn):
# No inplace should work
mod(inp, False).sum().backward()
self.assertEqual(hook_called[0], 1)
# Input inplace error should throw an error
with self.assertRaisesRegex(
RuntimeError,
"Output 0 of BackwardHookFunctionBackward is "
"a view and is being modified inplace.",
):
mod(inp.clone(), True)
# Input inplace error should throw an error if we try to re-use the view after they have
# been modified
local_inp = inp.clone()
out = mod(local_inp, False)
local_inp[0] *= 1
with self.assertRaisesRegex(
RuntimeError,
"Output 0 of BackwardHookFunctionBackward is "
"a view and its base or another view",
):
# Any operation involving the view will fail here
mod.inp + 2
# Output inplace error should throw an error
out = mod(inp, False)
with self.assertRaisesRegex(
RuntimeError,
"BackwardHookFunctionBackward is a view "
"and is being modified inplace.",
):
out += 1
def test_hook_non_full_warning(self):
def noop(*args):
pass
a = torch.rand(2, requires_grad=True)
b = torch.rand(2, requires_grad=True)
# Check invalid input container
class MyModule(nn.Module):
def forward(self, l):
return l[0].clone(), l[1].clone()
m = MyModule()
m.register_backward_hook(noop)
with self.assertWarnsRegex(
FutureWarning,
"does not take as input a single Tensor or a tuple of Tensors",
):
m([a, b])
# Check invalid output container
class MyModule(nn.Module):
def forward(self, a, b):
return [a.clone(), b.clone()]
m = MyModule()
m.register_backward_hook(noop)
with self.assertWarnsRegex(
FutureWarning, "does not return a single Tensor or a tuple of Tensors"
):
m(a, b)
# Check invalid output from different Nodes
class MyModule(nn.Module):
def forward(self, a, b):
return a.clone(), b.clone()
m = MyModule()
m.register_backward_hook(noop)
with self.assertWarnsRegex(
FutureWarning, "outputs are generated by different autograd Nodes"
):
m(a, b)
# Check invalid forward with multiple Nodes
class MyModule(nn.Module):
def forward(self, a):
return a.clone().clone()
m = MyModule()
m.register_backward_hook(noop)
with self.assertWarnsRegex(
FutureWarning, "the forward contains multiple autograd Nodes"
):
m(a)
def test_hook_backward_size(self):
# Make module with multiple operations in forward
# And different size for input and outputs
class MyModule(nn.Module):
def forward(self, arg1, arg2):
tmp = arg1.sum() * arg2
tmp = tmp + arg2.sum() * arg1.sum()
tmp = tmp.sum().view(1)
tmp = tmp.expand(8).contiguous()
return tmp
module = MyModule()
inp1 = torch.randn(5, 5, requires_grad=True)
inp2 = torch.randn(10, 10, requires_grad=True)
def bw_hook(module, grad_input, grad_output):
self.assertEqual(len(grad_input), 2)
self.assertEqual(grad_input[0].size(), torch.Size([5, 5]))
self.assertEqual(grad_input[1].size(), torch.Size([10, 10]))
self.assertEqual(len(grad_output), 1)
self.assertEqual(grad_output[0].size(), torch.Size([8]))
with module.register_full_backward_hook(bw_hook):
module(inp1, inp2).sum().backward()
def test_hook_backward_writeable(self):
module = nn.Sigmoid()
input = torch.randn(5, 5, requires_grad=True)
sig_x = torch.nn.functional.sigmoid(input)
def bw_hook(module, grad_input, grad_output):
for grad in grad_input:
self.assertTrue(isinstance(grad, torch.Tensor))
for grad in grad_output:
self.assertTrue(isinstance(grad, torch.Tensor))
return tuple(gi * 2 for gi in grad_input)
module.register_backward_hook(bw_hook)
module(input).backward(torch.ones(5, 5))
expected_grad = sig_x * (1 - sig_x) * 2
self.assertEqual(input.grad, expected_grad)
def test_hook_forward_preforward_writable(self):
module = nn.Sigmoid()
input = torch.randn(5, 5, requires_grad=True)
sig_x = torch.nn.functional.sigmoid(input)
def forward_pre_hook(m, input):
return torch.nn.functional.relu(input[0])
def forward_hook(m, input, output):
return -output
module.register_forward_pre_hook(forward_pre_hook)
module.register_forward_hook(forward_hook)
output = module(input)
expected_res = -torch.nn.functional.sigmoid(torch.nn.functional.relu(input))
self.assertEqual(output, expected_res)
output.backward(torch.ones(5, 5) * 2, retain_graph=True)
mask = input > 0
expected_grad = -sig_x * (1 - sig_x) * 2 * mask
self.assertEqual(input.grad, expected_grad)
def test_hook_buffer_registration(self):
for return_buffer in (True, False):
def buffer_registration_hook(module, name, buffer):
buffer.registered = True
if return_buffer:
return buffer
handle = torch.nn.modules.module.register_module_buffer_registration_hook(
buffer_registration_hook
)
try:
l, n, s = _create_basic_net()
for b in s.buffers():
self.assertTrue(getattr(b, "registered", False))
finally:
handle.remove()
def test_hook_submodule_registration(self):
for return_submodule in (True, False):
def module_registration_hook(module, name, submodule):
module.registered = True
submodule.registered = True
if return_submodule:
return submodule
handle = torch.nn.modules.module.register_module_module_registration_hook(
module_registration_hook
)
try:
l, n, s = _create_basic_net()
for m in s.modules():
self.assertTrue(getattr(m, "registered", False))
finally:
handle.remove()
def test_hook_parameter_registration(self):
for return_parameter in (True, False):
def parameter_registration_hook(module, name, parameter):
parameter.registered = True
if return_parameter:
return parameter
handle = (
torch.nn.modules.module.register_module_parameter_registration_hook(
parameter_registration_hook
)
)
try:
l, n, s = _create_basic_net()
for p in s.parameters():
self.assertTrue(getattr(p, "registered", False))
finally:
handle.remove()
# NOTE(review): presumably expands the parametrized tests declared on these
# classes into concrete test methods — confirm against
# torch.testing._internal.common_utils.instantiate_parametrized_tests.
instantiate_parametrized_tests(TestModuleHooks)
instantiate_parametrized_tests(TestStateDictHooks)
# Run the tests only when this file is executed as a script.
if __name__ == "__main__":
run_tests()