| # Owner(s): ["oncall: jit"] |
| |
| import io |
| import unittest |
| from itertools import product |
| from typing import Any |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.jit._recursive import wrap_cpp_module |
| from torch.testing import FileCheck |
| from torch.testing._internal.common_quantization import skipIfNoFBGEMM |
| from torch.testing._internal.common_quantized import override_quantized_engine |
| from torch.testing._internal.common_utils import set_default_dtype, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM |
| from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDA |
| from torch.testing._internal.jit_utils import JitTestCase |
| from torch.utils import mkldnn as mkldnn_utils |
| |
| try: |
| import torchvision |
| HAS_TORCHVISION = True |
| except ImportError: |
| HAS_TORCHVISION = False |
| skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") |
| |
| if __name__ == '__main__': |
| raise RuntimeError("This test file is not meant to be run directly, use:\n\n" |
| "\tpython test/test_jit.py TESTNAME\n\n" |
| "instead.") |
| |
| TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None |
| |
| def removeExceptions(graph): |
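    # Destroy every prim::RaiseException node in the graph, in place.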
| for n in graph.findAllNodes('prim::RaiseException'): |
| n.destroy() |
| |
| class TestFreezing(JitTestCase): |
| def test_freeze_module(self): |
| class M(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = 1 # folded |
| self.b = 1.2 # folded |
| self.c = "hello" # folded |
                self.c2 = "hi\xA1"  # folded
| self.d = [1, 1] # folded |
| self.e = [1.0, 1.1] # folded |
| self.f = ["hello", "world"] # folded |
| self.f2 = [(1, "Over \u0e55\u0e57 57")] |
| self.g = ([1, 2], 3.2, "4.4", torch.tensor([5.5], requires_grad=True)) # folded |
| self.h = {"layer" : [torch.tensor([7.7], requires_grad=True)]} |
| self.h2 = {"layer\xB1" : [torch.tensor([8.8], requires_grad=True)]} |
| self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded |
| self.ts = [torch.tensor([1.0, 2.0], requires_grad=True), torch.tensor([3.0, 4.0], requires_grad=True)] # folded |
| self.tt = [[torch.tensor([3.3, 2.3], requires_grad=True), None]] |
| |
| def forward(self, x): |
| return str(self.a) + str(self.b) + self.c + self.c2 + str(self.d) + \ |
| str(self.e) + str(self.f) + str(self.f2) + str(self.g) + \ |
| str(self.h) + str(self.h2) + str(self.t) + str(self.ts) + str(self.tt) |
| |
| |
| m = torch.jit.script(M()) |
| m.eval() |
| input = torch.randn(2, 2) |
| output_s = m.forward(input) |
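        # torch._C._freeze_module is the internal entry point wrapped by
        # torch.jit.freeze; it returns a frozen module whose foldable attributes
        # are baked into the forward graph as constants.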
| m._c = torch._C._freeze_module(m._c) |
| buffer = io.BytesIO() |
| torch.jit.save(m._c, buffer) |
| buffer.seek(0) |
| m2 = torch.jit.load(buffer) |
| # Check if frozen module looks as below: |
| # module m { |
| # attributes { |
| # } |
| # ... |
| # } |
| self.assertFalse(m2._c.hasattr('a')) |
| self.assertFalse(m2._c.hasattr('b')) |
| self.assertFalse(m2._c.hasattr('c')) |
| self.assertFalse(m2._c.hasattr('c2')) |
| self.assertFalse(m2._c.hasattr('d')) |
| self.assertFalse(m2._c.hasattr('e')) |
| self.assertFalse(m2._c.hasattr('f')) |
| self.assertFalse(m2._c.hasattr('f2')) |
| self.assertFalse(m2._c.hasattr('g')) |
| self.assertFalse(m2._c.hasattr('h')) |
| self.assertFalse(m2._c.hasattr('h2')) |
| self.assertFalse(m2._c.hasattr('t')) |
| self.assertFalse(m2._c.hasattr('ts')) |
| self.assertFalse(m2._c.hasattr('tt')) |
| output_f = m2.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| def test_freeze_module_with_submodule(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = 11 |
| self.b = 2 |
| |
| def forward(self, x): |
| return self.a + self.b |
| |
| class SubModule2(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = 12 |
| self.b = 2 |
| |
| def forward(self, x): |
| self.b = 30 |
| return self.a + self.b |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub1 = SubModule() |
| self.sub2 = SubModule2() |
| self.a = 3 |
| self.b = 4 |
| |
| def forward(self, x): |
| self.b = 20 |
| return self.sub1(x) + self.a + self.b + self.sub2(x) |
| |
| m = torch.jit.script(TestModule()) |
| m.eval() |
| input = torch.randn(2, 2) |
| output_s = m.forward(input) |
| mf = torch.jit.freeze(m) |
| |
| # Check if frozen module looks as below: |
| # module m { |
| # attributes { |
| # sub2 = ... |
| # b = |
| # } |
| # ... |
| # submodule { |
| # module m { |
| # attributes { |
| # sub2 = ... |
| # b = |
| # } |
| # ... |
| # } |
| # } |
| # } |
| mf = mf._c |
| self.assertFalse(mf.hasattr('sub1')) |
| self.assertFalse(mf.hasattr('a')) |
| self.assertTrue(mf.hasattr('b')) |
| self.assertTrue(mf.hasattr('sub2')) |
| self.assertTrue(mf.sub2.hasattr('b')) # verify b is preserved in sub2 |
| self.assertFalse(mf.sub2.hasattr('a')) # verify a is removed in sub2 |
| output_f = mf.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| def test_freeze_module_with_fork(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.ones(20, 20) |
| self.b = torch.ones(20, 20) |
| |
| def forward(self, x): |
| return self.a * self.b + x |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = SubModule() |
| |
| def forward(self, x): |
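                # torch.jit._fork runs self.sub.forward asynchronously and
                # returns a future; torch.jit._wait blocks until the result is ready.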
| fut = torch.jit._fork(self.sub.forward, x) |
| y_hat = self.sub(x) |
| y = torch.jit._wait(fut) |
| return y_hat + y |
| |
| m = torch.jit.script(TestModule()) |
| m.eval() |
| input = torch.randn(20, 20) |
| output_s = m.forward(input) |
| mf = torch._C._freeze_module(m._c) |
| |
| # Check if frozen module looks as below: |
| # module m { |
| # attributes { |
| # } |
| # ... |
| # submodule { |
| # } |
| # } |
| self.assertFalse(mf.hasattr('a')) |
| self.assertFalse(mf.hasattr('b')) |
| output_f = mf.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| def test_freeze_module_with_nested_fork(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.ones(20, 20) |
| self.b = torch.ones(20, 20) |
| |
| def forward(self, x): |
| return self.a * self.b + x |
| |
| class SubModule2(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = SubModule() |
| self.c = torch.ones(20, 20) |
| |
| def forward(self, x): |
| fut = torch.jit._fork(self.sub.forward, x) |
| y_hat = self.sub(x) |
| y = torch.jit._wait(fut) |
| return y_hat + y + self.c |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = SubModule2() |
| self.d = 1 |
| |
| def forward(self, x): |
| fut = torch.jit._fork(self.sub.forward, x) |
| y_hat = self.sub(x) |
| y = torch.jit._wait(fut) |
| self.d = 2 |
| return y_hat * y + self.d |
| |
| m = torch.jit.script(TestModule()) |
| m.eval() |
| input = torch.randn(20, 20) |
| output_s = m.forward(input) |
| mf = torch._C._freeze_module(m._c) |
| # Check if frozen module looks as below: |
| # module m { |
| # attributes { |
| # } |
| # ... |
| # submodule { |
| # } |
| # } |
| self.assertFalse(mf.hasattr('a')) |
| self.assertFalse(mf.hasattr('b')) |
| self.assertFalse(mf.hasattr('c')) |
| self.assertTrue(mf.hasattr('d')) |
| output_f = mf.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| |
| def test_freeze_module_with_fork2(self): |
| @torch.jit.script |
| def foo(x): |
| return x * 2 |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.ones(20, 20) |
| self.b = torch.ones(20, 20) |
| |
| def forward(self, x): |
| fut = torch.jit._fork(foo, self.a) |
| y_hat = foo(self.b) |
| y = torch.jit._wait(fut) |
| return y_hat + y |
| |
| m = torch.jit.script(TestModule()) |
| m.eval() |
| input = torch.randn(2, 2) |
| output_s = m.forward(input) |
| mf = torch._C._freeze_module(m._c) |
| |
| # Check if frozen module looks as below: |
| # module m { |
| # attributes { |
| # self.a = ... |
| # self.b = .. |
| # } |
| # ... |
| # submodule { |
| # } |
| # } |
        # TODO: Although there is no mutation, alias analysis conservatively
        # assumes one because 'a' is passed to the fork subgraph, so 'a' is
        # preserved while 'b' is folded.
| self.assertTrue(mf.hasattr('a')) |
| self.assertFalse(mf.hasattr('b')) |
| output_f = mf.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| def test_freeze_module_with_fork_calling_module_method(self): |
| @torch.jit.script |
| def foo(x, y): |
| return x * y |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.ones(20, 20) |
| self.b = torch.ones(20, 20) |
| |
| @torch.jit.export |
| def foo(self, x): |
| return x * self.a |
| |
| @torch.jit.export |
| def bar(self, x): |
| return x * self.b |
| |
| def forward(self, x): |
| fut = torch.jit._fork(self.foo, self.b) |
| y_hat = self.bar(self.a) |
| y = torch.jit._wait(fut) |
| return y_hat + y |
| |
| m = torch.jit.script(TestModule()) |
| m.eval() |
| input = torch.randn(2, 2) |
| output_s = m.forward(input) |
| mf = torch._C._freeze_module(m._c) |
| # Check if frozen module looks as below: |
| # module m { |
| # attributes { |
| # self.b = .. |
| # } |
| # ... |
        # TODO: Although there is no mutation, alias analysis conservatively
        # assumes one because attributes are passed to the fork subgraph, so
        # 'b' is preserved.
| self.assertFalse(mf.hasattr('a')) |
| self.assertTrue(mf.hasattr('b')) |
| output_f = mf.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| def test_freeze_module_with_sharedclasstype(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1.1]) |
| self.b = torch.tensor([2.2]) |
| |
| def forward(self, x): |
| return self.a + self.b |
| |
| @torch.jit.export |
| def modify_a(self, x): |
| self.a[0] += 10 |
                return self.b
| |
| @torch.jit.export |
| def modify_b(self, x): |
| self.b[0] += 20 |
| return self.a |
| |
| class SubModule2(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = SubModule() |
| self.b = torch.tensor([3.3]) |
| |
| def forward(self, x): |
| y = self.sub.modify_b(x) |
| return y + self.b |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub1 = SubModule() # sub1 and sub2.sub shared same class type. |
| self.sub2 = SubModule2() |
| self.a = torch.tensor([4.4]) |
| |
| def forward(self, x): |
| z = self.sub1.modify_a(x) |
| return self.sub2(x) + z + self.a |
| |
| m = torch.jit.script(TestModule()) |
| m.eval() |
| input = torch.randn(2, 2) |
| output_s = m.forward(input) |
| mf = torch._C._freeze_module(m._c) |
| |
        # Check if frozen module looks as below:
| # module mf { |
| # attributes { |
| # sub1 = ... |
| # sub2 = ... |
| # } |
| # ... |
| # submodules { |
| # module sub1 { |
| # attributes { |
| # a = ... |
| # b = ... |
| # } |
| # ... |
| # } |
| # module sub2 { |
| # attributes { |
| # sub = ... |
| # } |
| # ... |
| # submodule { |
| # module sub { |
| # attributes { |
| # a = ... |
| # b = ... |
| # } |
| # ... |
| # } |
| # } |
| # } |
| # } |
| # } |
| |
| self.assertTrue(mf.hasattr('sub1')) |
| self.assertTrue(mf.sub1.hasattr('a')) |
| self.assertTrue(mf.sub1.hasattr('b')) |
| self.assertFalse(mf.hasattr('a')) |
| self.assertTrue(mf.hasattr('sub2')) |
| self.assertTrue(mf.sub2.hasattr('sub')) |
| self.assertFalse(mf.sub2.hasattr('b')) |
| self.assertTrue(mf.sub2.sub.hasattr('a')) |
| self.assertTrue(mf.sub2.sub.hasattr('b')) |
| output_f = mf.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| def test_freeze_module_with_nestedaliasing(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1.1]) |
| self.b = torch.tensor([2.2]) |
| |
| def forward(self, x): |
| return self.a + self.b |
| |
| @torch.jit.export |
| def modify_a(self, x): |
| self.a[0] = 10 |
                return self.b
| |
| @torch.jit.export |
| def modify_b(self, x): |
| self.b[0] = 20 |
| return self.a |
| Sub = SubModule() |
| |
| class SubModule2(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = Sub # aliasing |
| |
| def forward(self, x): |
| return self.sub.a |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub1 = Sub # aliasing |
| self.sub2 = SubModule2() |
| |
| def forward(self, x): |
| z = self.sub1.modify_a(x) |
| return self.sub2(x) + z |
| |
| m = torch.jit.script(TestModule()) |
| m.eval() |
| mf = torch._C._freeze_module(m._c) |
| self.assertTrue(mf.hasattr('sub1')) |
| self.assertTrue(mf.sub1.hasattr('a')) |
| self.assertFalse(mf.sub1.hasattr('b')) |
| self.assertTrue(mf.hasattr('sub2')) |
| self.assertTrue(mf.sub2.hasattr('sub')) |
        self.assertTrue(mf.sub2.sub.hasattr('a'))  # Freezing detects that self.sub2.sub.a and self.sub1.a are aliases
| self.assertFalse(mf.sub2.sub.hasattr('b')) |
| input = torch.randn(2, 2) |
| output_s = m.forward(input) |
| output_f = mf.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
    # FIXME: JIT is not honoring aliasing. The 'Sub' module is copied, so the
    # eager and scripted modules produce different outputs.
| def test_freeze_module_with_nestedaliasingscalar(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = 1.1 |
| self.b = 2.2 |
| |
| def forward(self, x): |
| return self.a + self.b |
| |
| @torch.jit.export |
| def modify_a(self, x): |
| self.a = 10.0 |
                return self.b
| |
| @torch.jit.export |
| def modify_b(self, x): |
| self.b = 20.0 |
| return self.a |
| Sub = SubModule() |
| |
| class SubModule2(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = Sub # aliasing |
| |
| def forward(self, x): |
| return self.sub.a |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub1 = Sub # aliasing |
| self.sub2 = SubModule2() |
| |
| def forward(self, x): |
| z = self.sub1.modify_a(x) |
| return self.sub2(x) + z |
| m = TestModule() |
| ms = torch.jit.script(m) |
| ms.eval() |
| mf = torch._C._freeze_module(ms._c) |
| self.assertTrue(mf.hasattr('sub1')) |
| self.assertTrue(mf.sub1.hasattr('a')) |
| self.assertFalse(mf.sub1.hasattr('b')) |
        # sub2 is fully folded because self.sub1 and self.sub2.sub are not aliases (scripting bug)
| self.assertFalse(mf.hasattr('sub2')) |
| input = torch.randn(2, 2) |
| output = m.forward(input) |
| output_s = ms.forward(input) |
| output_f = mf.forward(input) |
        # These should be equal, but differ because scripting copies 'Sub' (see FIXME above).
| self.assertNotEqual(output, output_s) |
| self.assertEqual(output_s, output_f) |
| |
| |
| def test_freeze_module_with_preserve_sub_module(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1.1]) |
| self.b = 2.2 |
| |
| def forward(self, x): |
| return self.a |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
                self.sub1 = SubModule()
| self.sub2 = SubModule() |
| |
| def forward(self, x): |
| return self.sub2(x) + self.sub1(x) |
| m = TestModule() |
| ms = torch.jit.script(m) |
| ms.eval() |
| mf = torch._C._freeze_module(ms._c, ["sub1"]) |
| |
| # Test that 'sub1' is preserved entirely and 'sub2' is completely folded |
| self.assertTrue(mf.hasattr('sub1')) |
| self.assertTrue(mf.sub1.hasattr('a')) |
| self.assertTrue(mf.sub1.hasattr('b')) |
| self.assertFalse(mf.hasattr('sub2')) |
| input = torch.randn(2, 2) |
| output_s = ms.forward(input) |
| output_f = mf.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| def test_freeze_module_with_preserve_sub_module_and_mutation(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1.1]) |
| self.b = 2.2 |
| |
| def forward(self, x): |
| self.a[0] = 3.3 |
| return self.a |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
                self.sub1 = SubModule()
| self.sub2 = SubModule() |
| |
| def forward(self, x): |
| return self.sub2(x) + self.sub1(x) |
| m = TestModule() |
| ms = torch.jit.script(m) |
| ms.eval() |
| mf = torch._C._freeze_module(ms._c, ["sub1"]) |
| |
        # Test that both sub1 and sub2 are preserved, and that 'b' is kept even
        # though it is unused, to fulfill the user request to preserve 'sub1'.
| self.assertTrue(mf.hasattr('sub1')) |
| self.assertTrue(mf.sub1.hasattr('a')) |
| self.assertTrue(mf.sub1.hasattr('b')) |
| self.assertTrue(mf.hasattr('sub2')) |
| self.assertTrue(mf.sub2.hasattr('a')) |
| self.assertTrue(mf.sub2.hasattr('b')) |
| input = torch.randn(2, 2) |
| output_s = ms.forward(input) |
| output_f = mf.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| |
| def test_freeze_module_with_helperfunction(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = 11 |
| self.b = 2 |
| |
| def forward(self, x): |
| return self.a + self.b |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = SubModule() |
| self.a = 3 |
| self.b = 4 |
| |
| def forward(self, x): |
| self.b = 20 |
| return self._forward(x) + self.a + self.b |
| |
| def _forward(self, x): |
| return self.sub(x) |
| m = torch.jit.script(TestModule()) |
| m.eval() |
| input = torch.randn(2, 2) |
| mf = torch._C._freeze_module(m._c) |
| self.assertFalse(mf.hasattr('sub')) |
| self.assertFalse(mf.hasattr('a')) |
| self.assertTrue(mf.hasattr('b')) |
| with self.assertRaisesRegex(AttributeError, "TestModule (.*) does not have a field with name '_forward'"): |
            mf._forward(input)
| |
| def test_freeze_module_with_inplace_mutable(self): |
| class FreezeMe(torch.jit.ScriptModule): |
| def __init__(self): |
| super().__init__() |
| self.a = [11, 22] |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| for i in range(3): |
| self.a.append(i) |
| return self.a |
| |
| m = FreezeMe() |
| m.eval() |
| m_f = torch._C._freeze_module(m._c) |
| self.assertTrue(m_f.hasattr('a')) |
| m.forward(torch.tensor([3])) |
| out = m_f.forward(torch.tensor([5])) |
| expected = [11, 22, 0, 1, 2, 0, 1, 2] |
| self.assertEqual(out, expected) |
| |
| # Mutable attributes |
| def test_freeze_module_with_mutable_list(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = [1, 2] |
| |
| def forward(self, x): |
| return self.a |
| |
| m = FreezeMe() |
| m.eval() |
| m.a.append(3) |
| m_s = torch.jit.script(m) |
| v = m_s.a |
| v.append(4) |
| m_s.a = v |
| m_s.eval() |
| m_f = torch._C._freeze_module(m_s._c) |
        # Mutating m_s.a after freezing does not affect m_f (m_f has its own copy).
| v = m_s.a |
| v.append(5) |
| m_s.a = v |
| self.assertFalse(m_f.hasattr('a')) |
| out = m_f.forward(torch.tensor([5])) |
| expected = [1, 2, 3, 4] |
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_with_mutable_dict(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = {"layer" : "4"} |
| |
| def forward(self, x): |
| return self.a |
| |
| @torch.jit.export |
| def modify_a(self, x): |
| self.a["layer"] = self.a["layer"] + "1" |
| return self.a |
| |
| m = FreezeMe() |
| m.eval() |
| m.a["layer2"] = "3" |
| m_s = torch.jit.script(m) |
| t = torch.tensor(5) |
| m_s.modify_a(t) |
| m_s.eval() |
| m_f = torch._C._freeze_module(m_s._c) |
| m.a["layer2"] += "2" |
| m_s.modify_a(t) |
| self.assertFalse(m_f.hasattr('a')) |
| out = m_f.forward(t) |
| expected = {"layer" : "411", "layer2" : "3"} |
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_with_mutable_tensor(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1., 2., 3.]) |
| |
| def forward(self, x): |
| return self.a |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.a[1] += 3.0 |
| m_s.eval() |
| m_f = torch._C._freeze_module(m_s._c) |
        # Mutating the tensor attribute after freezing still affects m_f.
| # FIXME: deep copy all folded attributes so that m_f has full ownership. |
| m_s.a[0] += 5.0 |
| self.assertFalse(m_f.hasattr('a')) |
| out = m_f.forward(torch.tensor([5])) |
| expected = [6., 5., 3.] |
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_with_tuple(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = (torch.tensor([1, 2, 3, 4, 5, 6]), "hi") |
| |
| def forward(self, x): |
| if (x[0] == 2.0): |
| self.a[0][0] = 10 |
| return self.a[0].sum() |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| inp = torch.tensor([2.0]) |
| expected = m_s.forward(inp) |
| m_s.a[0][0] = 1 |
| m_f = torch._C._freeze_module(m_s._c) |
| self.assertFalse(m_f.hasattr('a')) |
| out = m_f.forward(inp) |
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_with_tensor(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1, 2, 3, 4, 5, 6]) |
| |
| def forward(self, x): |
| x = self.a.view(2, 3) |
| x[0][0] += 10 |
| return self.a.sum() |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| inp = torch.tensor([5]) |
| expected = m_s.forward(inp) |
| m_f = torch._C._freeze_module(m_s._c) |
| self.assertTrue(m_f.hasattr('a')) |
| m_f.a[0] -= 10 |
| out = m_f.forward(inp) |
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_with_list(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = [torch.tensor([1, 2, 3, 4, 5, 6])] |
| |
| def forward(self, x): |
| self.a[0][1] += 10 |
| return self.a[0].sum() |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| inp = torch.tensor([5]) |
| expected = m_s.forward(inp) |
| m_s.a[0][1] -= 10 |
| m_f = torch._C._freeze_module(m_s._c) |
| self.assertFalse(m_f.hasattr('a')) |
| out = m_f.forward(inp) |
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_with_aliased_tensor_attr(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1, 2, 3, 4, 5, 6]) |
| self.b = self.a.view(2, 3) |
| |
| def forward(self, x): |
| self.b[1] += 10 |
| return self.a.sum() |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| m_f = torch._C._freeze_module(m_s._c) |
| self.assertTrue(m_f.hasattr('a')) |
| inp = torch.tensor([5]) |
| out = m_f.forward(inp) |
| expected = torch.tensor(51) # 1+2+3+14+15+16 |
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_with_aliased_tensor_attr2(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1, 2, 3, 4, 5, 6]) |
| self.b = {"layer" : ([self.a.view(2, 3), torch.tensor([10])], 20)} |
| self.c = ([self.a.view(2, 3), torch.tensor([10])], 20) |
| self.d = (self.a.view(2, 3), 20) |
| |
| def forward(self, x): |
| self.d[0][0] += 10 |
| return self.a.sum() |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| inp = torch.tensor([5]) |
| expected = m_s.forward(inp) |
| with self.assertRaisesRegex(RuntimeError, "module contains attributes values that overlaps"): |
| m_f = torch._C._freeze_module(m_s._c) |
| |
| def test_freeze_module_with_aliased_tensor_attr3(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1, 2, 3, 4, 5, 6]) |
| self.b = [self.a, torch.tensor([10])] |
| |
| def forward(self, x): |
| self.a[1] += 10 |
| return self.b[0].sum() |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| inp = torch.tensor([5]) |
| expected = m_s.forward(inp) |
| m_f = torch._C._freeze_module(m_s._c) |
| self.assertTrue(m_f.hasattr('a')) |
| self.assertTrue(m_f.hasattr('b')) |
| out = m_f.forward(inp) |
        expected += 10  # account for the self.a[1] += 10 in m_f.forward.
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_with_aliased_tensor_attr4(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1, 2, 3, 4, 5, 6]) |
| self.b = [self.a, torch.tensor([10])] |
| |
| def forward(self, x): |
| self.b[0][0] += 10 |
| return self.a.sum() |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| inp = torch.tensor([5]) |
| expected = m_s.forward(inp) |
| m_s.a[0] -= 10 |
| with self.assertRaisesRegex(RuntimeError, "module contains attributes values that overlaps"): |
| m_f = torch._C._freeze_module(m_s._c) |
| |
| def test_freeze_module_with_overlapping_attrs(self): |
| a = torch.tensor([1, 2, 3, 4, 5, 6]) |
| |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.b = [a.view(3, 2), torch.tensor([10])] |
| self.c = (20, a.view(2, 3)) |
| |
| def forward(self, x): |
| self.b[0][0] += 10 |
| return self.c[1].sum() |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| inp = torch.tensor([5]) |
| expected = m_s.forward(inp) |
| a[0] -= 10 |
| with self.assertRaisesRegex(RuntimeError, "module contains attributes values that overlaps"): |
| m_f = torch._C._freeze_module(m_s._c) |
| |
| def test_freeze_module_with_aliased_attr(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = [1, 2, 3, 4, 5, 6] |
| self.b = self.a |
| self.c = (self.a, 10) |
| |
| def forward(self, x): |
| self.b[1] += 10 |
| return str(self.a) + str(self.c) |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| m_f = torch._C._freeze_module(m_s._c) |
| # FIXME: It should be assertTrue. Currently scripting is making a copy for setting self.b (see #33034) |
| self.assertFalse(m_f.hasattr('a')) |
| self.assertFalse(m_f.hasattr('c')) |
| inp = torch.tensor([5]) |
| out = m_f.forward(inp) |
| expected = m_s.forward(inp) |
| self.assertEqual(out, expected) |
| |
    # Check that attribute 'a' is preserved. Alias analysis detects that 'a' has
    # output writers. In this example, 'a' is not actually mutated; however, we
    # do not track which sub-values of a composite IValue are mutated.
| def test_freeze_module_with_aliased_attr2(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = [1, 2, 3, 4, 5, 6] |
| self.b = ([11], [10]) |
| |
| def forward(self, x): |
| v = self.a |
| self.b = (v, [12]) |
| v2 = self.b[1] |
| v2.append(7) |
| return str(v) + str(v2) |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| m_f = torch._C._freeze_module(m_s._c) |
| self.assertTrue(m_f.hasattr('a')) |
| inp = torch.tensor([5]) |
| out = m_f.forward(inp) |
| expected = m.forward(inp) |
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_with_aliased_attr3(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = [1, 2, 3, 4, 5, 6] |
| self.b = ([11], [10]) |
| |
| def forward(self, x): |
| v = self.a |
| v2 = (v, [12]) |
| v3 = v2[0] |
| v3.append(7) |
| return str(self.a) |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| m_f = torch._C._freeze_module(m_s._c) |
| self.assertTrue(m_f.hasattr('a')) |
| inp = torch.tensor([5]) |
| out = m_f.forward(inp) |
| expected = m.forward(inp) |
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_return_self(self): |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1., 2., 3.]) |
| |
| def forward(self, x): |
| return self |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| with self.assertRaisesRegex(RuntimeError, "attempted to freeze a module that return itself"): |
| m_f = torch._C._freeze_module(m_s._c) |
| |
| def test_freeze_module_inlining(self): |
| @torch.jit.script # noqa: B903 |
| class Obj: # noqa: B903 |
| def __init__(self, x: int, y: int): |
| self.x = x |
| self.y = y |
| |
| class Mod(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.obj = Obj(2, 3) |
| |
| def forward(self, i: int): |
| print(self.obj) |
| return i |
| |
| mod = torch.jit.freeze(torch.jit.script(Mod().eval())) |
| obj = mod.graph.findNode("prim::Constant") |
| self.assertTrue(torch._C._jit_object_is_non_holding(obj)) |
| |
| buffer = io.BytesIO() |
| torch.jit.save(mod, buffer) |
| buffer.seek(0) |
| |
| loaded = torch.jit.load(buffer) |
        obj = loaded.graph.findNode("prim::Constant")
| self.assertTrue(torch._C._jit_object_is_non_holding(obj)) |
| |
| def test_freeze_module_return_sub_module(self): |
| |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(1, 32, 3, 1) |
| |
| def forward(self, x): |
| return self.conv1 |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| m_f = torch._C._freeze_module(m_s._c) |
| self.assertTrue(m_f.hasattr('conv1')) |
| |
| def test_freeze_module_no_forward(self): |
| |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lin = nn.Linear(10, 1) |
| |
| @torch.jit.export |
| def foo(self, x): |
| return self.lin(x) |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
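        # FreezeMe has no forward(), so the exported method must be listed in
        # preservedAttrs to keep it (and the attributes it uses) after freezing.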
| m_f = torch._C._freeze_module(m_s._c, preservedAttrs=['foo']) |
| input = torch.ones(10) |
| self.assertEqual(m_s.foo(input), m_f.foo(input)) |
| |
| |
| def test_freeze_no_forward(self): |
| |
| class FreezeMe(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lin = nn.Linear(10, 1) |
| |
| @torch.jit.export |
| def foo(self, x): |
| return self.lin(x) |
| |
| m = FreezeMe() |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| m_f = torch.jit.freeze(m_s, preserved_attrs=['foo']) |
| input = torch.ones(10) |
| self.assertEqual(m_s.foo(input), m_f.foo(input)) |
| |
| |
| def test_freeze_module_in_training_mode(self): |
| class Net(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(1, 32, 3, 1) |
| self.conv2 = nn.Conv2d(32, 64, 3, 1) |
| self.dropout1 = nn.Dropout2d(0.25) |
| self.dropout2 = nn.Dropout2d(0.5) |
| self.fc1 = nn.Linear(9216, 128) |
| self.fc2 = nn.Linear(128, 10) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = nn.functional.relu(x) |
| x = self.conv2(x) |
| x = nn.functional.max_pool2d(x, 2) |
| x = self.dropout1(x) |
| x = torch.flatten(x, 1) |
| x = self.fc1(x) |
| x = nn.functional.relu(x) |
| x = self.dropout2(x) |
| x = self.fc2(x) |
| output = nn.functional.log_softmax(x, dim=1) |
| return output |
| |
| model = torch.jit.script(Net()) |
| model.train() |
| mTrain_freezed = torch._C._freeze_module(model._c) |
        # verify that mTrain_freezed looks exactly like:
| # module { |
| # attributes { |
| # conv1 = ... |
| # conv2 = ... |
| # dropout1 = ... |
| # dropout2 = ... |
| # fc1 = ... |
| # fc2 = ... |
| # } |
| # ... |
| # submodules { |
| # module conv1 { |
| # attributes { |
| # weight = ... |
| # bias = ... |
| # } |
| # ... |
| # } |
| # module conv2 { |
| # attributes { |
| # weight = ... |
| # bias = ... |
| # } |
| # ... |
| # } |
| # module dropout1 { |
| # attributes { |
| # training = ... |
| # } |
| # ... |
| # } |
| # module dropout2 { |
| # attributes { |
| # training = ... |
| # } |
| # ... |
| # } |
| # module fc1 { |
| # attributes { |
| # weight = ... |
| # bias = ... |
| # } |
| # ... |
| # } |
| # module fc2 { |
| # attributes { |
| # weight = ... |
| # bias = ... |
| # } |
| # ... |
| # } |
| self.assertFalse(mTrain_freezed.hasattr('training')) |
| self.assertTrue(mTrain_freezed.hasattr('conv1')) |
| self.assertFalse(mTrain_freezed.conv1.hasattr('training')) |
| self.assertTrue(mTrain_freezed.conv1.hasattr('weight')) |
| self.assertTrue(mTrain_freezed.conv1.hasattr('bias')) |
| self.assertTrue(mTrain_freezed.hasattr('conv2')) |
| self.assertFalse(mTrain_freezed.conv2.hasattr('training')) |
| self.assertTrue(mTrain_freezed.conv2.hasattr('weight')) |
| self.assertTrue(mTrain_freezed.conv2.hasattr('bias')) |
| self.assertTrue(mTrain_freezed.hasattr('dropout1')) |
| self.assertTrue(mTrain_freezed.dropout1.hasattr('training')) |
| self.assertTrue(mTrain_freezed.hasattr('dropout2')) |
| self.assertTrue(mTrain_freezed.dropout2.hasattr('training')) |
| self.assertTrue(mTrain_freezed.hasattr('fc1')) |
| self.assertTrue(mTrain_freezed.fc1.hasattr('weight')) |
| self.assertTrue(mTrain_freezed.fc1.hasattr('bias')) |
| self.assertTrue(mTrain_freezed.hasattr('fc2')) |
| self.assertTrue(mTrain_freezed.fc2.hasattr('weight')) |
| self.assertTrue(mTrain_freezed.fc2.hasattr('bias')) |
| model.eval() |
| mEval_freezed = torch._C._freeze_module(model._c) |
| self.assertFalse(mEval_freezed.hasattr('conv1')) |
| self.assertFalse(mEval_freezed.hasattr('conv2')) |
| self.assertFalse(mEval_freezed.hasattr('dropout1')) |
| self.assertFalse(mEval_freezed.hasattr('training')) |
| self.assertFalse(mEval_freezed.hasattr('fc1')) |
| self.assertFalse(mEval_freezed.hasattr('dropout2')) |
| self.assertFalse(mEval_freezed.hasattr('fc2')) |
| with self.assertRaisesRegex(AttributeError, "does not have a field with name 'state_dict'"): |
| print(mEval_freezed.state_dict()) |
| buffer = io.BytesIO() |
| torch.jit.save(mEval_freezed, buffer) |
| buffer.seek(0) |
| m = torch.jit.load(buffer) |
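        # In eval mode all attribute reads were folded, so the reloaded frozen
        # module's forward graph should contain no prim::GetAttr nodes.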
| FileCheck().check_not('GetAttr[name=') \ |
| .run(m._c._get_method('forward').graph) |
| m2 = torch._C._freeze_module(model._c, preserveParameters=True) |
| self.assertTrue(m2.hasattr('conv1')) |
| self.assertTrue(m2.hasattr('conv2')) |
| self.assertFalse(m2.hasattr('dropout1')) |
| self.assertFalse(m2.hasattr('training')) |
| self.assertTrue(m2.hasattr('fc1')) |
| self.assertFalse(m2.hasattr('dropout2')) |
| self.assertTrue(m2.hasattr('fc2')) |
| |
| def test_freeze_module_detach_gradient(self): |
| mod = nn.Conv2d(8, 3, 4, 2, 1) |
| self.assertTrue(mod.weight.requires_grad) |
| smod = torch.jit.script(mod) |
| smod.eval() |
| fmod = torch._C._freeze_module(smod._c) |
| self.assertTrue(mod.weight.requires_grad) |
| self.assertTrue(smod.weight.requires_grad) |
| self.assertFalse(fmod.hasattr('weight')) |
| inp = torch.ones(1, 8, 32, 32) |
| out1 = fmod.forward(inp) |
| # FIXME: frozen module mutated from outside (original module). |
| with torch.no_grad(): |
| smod.weight[0, 0, 0, 0] += 100.0 |
| out2 = fmod.forward(inp) |
| out3 = smod(inp) |
| self.assertNotEqual(out1, out2) |
| self.assertEqual(out2, out3) |
| |
| def test_freeze_module_with_user_preserved_attr(self): |
| class Module(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1.1]) |
| self.b = torch.tensor([2.2]) |
| |
| def forward(self, x): |
| return self.a + self.b |
| |
| m = torch.jit.script(Module()) |
| m.eval() |
| fm = torch._C._freeze_module(m._c, ["a"]) |
| # Attribute "a" is preserved |
| self.assertTrue(fm.hasattr("a")) |
| self.assertFalse(fm.hasattr("b")) |
| |
| def test_freeze_module_with_user_preserved_method(self): |
| class Module(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1.1]) |
| self.b = torch.tensor([2.2]) |
| |
| def forward(self, x): |
| return self.a + self.b |
| |
| @torch.jit.export |
| def modify_a(self, x): |
| self.a[0] += 10 |
| return self.b |
| |
| @torch.jit.export |
| def modify_b(self, x): |
| self.b[0] += 20 |
| return self.a |
| |
| m = torch.jit.script(Module()) |
| m.eval() |
| fm = torch._C._freeze_module(m._c, ["modify_a"]) |
| # Both attribute "a" and method "modify_a" are preserved |
| self.assertTrue(fm.hasattr("a")) |
| self.assertFalse(fm.hasattr("b")) |
| input = torch.randn(2, 2) |
| expected = m.forward(input) |
| out = fm.forward(input) |
| self.assertEqual(out, expected) |
| |
| def test_freeze_module_with_user_preserved_method2(self): |
| class Module(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = torch.tensor([1.1]) |
| self.b = torch.tensor([2.2]) |
| |
| def forward(self, x): |
| self.b += 10 |
| return self.a + self.b |
| |
| @torch.jit.export |
| def modify_a(self, x): |
| self.a[0] += 10 |
| return self.b + self.a |
| |
| m = torch.jit.script(Module()) |
| m.eval() |
| fm = torch._C._freeze_module(m._c, ["modify_a"]) |
| FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph) |
| FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph) |
| |
| def test_freeze_module_with_user_preserved_attribute_on_submodule(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = 1 |
| self.b = 2 |
| |
| def forward(self): |
| return self.a + self.b |
| |
| class Module(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub1 = SubModule() |
| self.sub2 = SubModule() |
| |
| def forward(self): |
| return self.sub1() + self.sub2() |
| |
| m = torch.jit.script(Module()) |
| m.eval() |
| m = torch.jit.freeze(m, preserved_attrs=['sub1.a', 'sub2.a']) |
| fm = m._c |
| |
| self.assertTrue(fm.hasattr('sub1')) |
| self.assertTrue(fm.sub1.hasattr('a')) |
| self.assertFalse(fm.sub1.hasattr('b')) |
| self.assertTrue(fm.hasattr('sub2')) |
| self.assertTrue(fm.sub2.hasattr('a')) |
| self.assertFalse(fm.sub2.hasattr('b')) |
| self.assertEqual(m(), 6) |
| m.sub1.a += 1 |
| self.assertEqual(m(), 7) |
| |
| def test_freeze_module_with_user_preserved_attribute_on_unused_submodule(self): |
| class SubModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = 1 |
| self.b = 2 |
| |
| def forward(self): |
| return self.a + self.b |
| |
| @torch.jit.export |
| def method_a(self): |
| return 42 |
| |
| class Module(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = SubModule() |
| |
| def forward(self): |
| return 1 |
| |
| m = torch.jit.script(Module()) |
| m.eval() |
| fm = torch.jit.freeze(m, preserved_attrs=['sub.a', 'sub.method_a'])._c |
| |
| self.assertTrue(fm.hasattr('sub')) |
| self.assertTrue(fm.sub.hasattr('a')) |
| self.assertFalse(fm.sub.hasattr('b')) |
| self.assertTrue(fm.sub._has_method('method_a')) |
| |
| def test_freeze_module_with_user_preserved_method_on_submodule(self): |
| class SubModule(nn.Module): |
| def forward(self, x): |
| return self.method_a(x) + self.method_b(x) |
| |
| def method_a(self, x): |
| return x * x |
| |
| def method_b(self, x): |
| return x + x |
| |
| class Module(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = SubModule() |
| |
| def forward(self, x): |
| return self.sub(x) |
| |
| m = torch.jit.script(Module()) |
| m.eval() |
| fm = torch.jit.freeze(m, preserved_attrs=['sub.method_a'])._c |
| |
| self.assertTrue(fm.hasattr('sub')) |
| self.assertTrue(fm.sub._has_method('method_a')) |
| self.assertFalse(fm.sub._has_method('method_b')) |
| |
| @skipIfNoFBGEMM |
| def test_module_with_shared_type_instances(self): |
| class Child(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| return x |
| |
| class Parent(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.quant = torch.ao.quantization.QuantStub() |
| self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32) |
| self.child = Child() |
| self.child2 = Child() |
| self.dequant = torch.ao.quantization.DeQuantStub() |
| |
| def forward(self, x): |
| x = self.quant(x) |
| x = self.conv1(x) |
| x = self.child(x) |
| x = self.child2(x) |
| x = self.dequant(x) |
| return x |
| |
| def _static_quant(model): |
| qModel = torch.ao.quantization.QuantWrapper(model) |
| qModel.qconfig = torch.ao.quantization.default_qconfig |
| torch.ao.quantization.prepare(qModel, inplace=True) |
| qModel(torch.rand(4, 1, 4, 4, dtype=torch.float32)) |
| torch.ao.quantization.convert(qModel, inplace=True) |
| return model |
| |
| with override_quantized_engine('fbgemm'): |
| data = torch.randn(4, 1, 4, 4, dtype=torch.float32) |
| m = Parent().to(torch.float32) |
| m = _static_quant(m) |
| m = torch.jit.script(m) |
| m.eval() |
| torch._C._jit_pass_inline(m.graph) |
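            # Freeze the raw C++ module and rewrap it as a ScriptModule so it can
            # be called directly.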
| m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c)) |
| # Earlier bug resulted in _packed_params set to false. |
| FileCheck().check_not('_packed_params = False').run(m_frozen._c.dump_to_str(True, True, False)) |
| |
| m_res = m(data) |
            # This used to segfault when running the frozen module.
| m_frozen_res = m_frozen(data) |
| self.assertEqual(m_res, m_frozen_res) |
| |
| def test_module_getattr_indirection(self): |
| @torch.jit.script |
| class ValHolder: |
| def __init__(self, val: int): |
| self.val: int = val |
| |
| class Mod(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mod1 = ValHolder(1) |
| self.mod2 = ValHolder(2) |
| |
| def forward(self, cond: bool): |
| if cond: |
| mod = self.mod1 |
| else: |
| mod = self.mod2 |
| return mod.val |
| |
| mod = Mod() |
| mod.eval() |
| frozen_mod = torch.jit.freeze(torch.jit.script(mod)) |
| mod_eager = Mod() |
| self.assertEqual(mod_eager(True), frozen_mod(True)) |
| self.assertEqual(mod_eager(False), frozen_mod(False)) |
| |
| def test_freeze_module_with_non_static_module_container_index(self): |
| """ |
| Test that Modules containing non-static ModuleDict or ModuleList |
| indexing cannot be frozen. |
| """ |
| @torch.jit.interface |
| class ModuleInterface(torch.nn.Module): |
| def forward(self, inp: Any) -> Any: |
| pass |
| |
| class ImplementsInterface(torch.nn.Module): |
| def forward(self, inp: Any) -> Any: |
| if isinstance(inp, torch.Tensor): |
| return torch.max(inp, dim=0) |
| |
| return inp |
| |
| class ModWithDict(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) |
| |
| def forward(self, x: torch.Tensor, key: str) -> Any: |
| value: ModuleInterface = self.d[key] |
| return value.forward(x) |
| |
| m = torch.jit.script(ModWithDict()) |
| m.eval() |
| with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleContainerIndex is not supported"): |
| mf = torch._C._freeze_module(m._c) |
| |
| class ModWithList(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.l = torch.nn.ModuleList([ImplementsInterface()]) |
| |
| def forward(self, x: torch.Tensor, idx: int) -> Any: |
| value: ModuleInterface = self.l[idx] |
| return value.forward(x) |
| |
| m = torch.jit.script(ModWithList()) |
| m.eval() |
| with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleContainerIndex is not supported"): |
| mf = torch._C._freeze_module(m._c) |
| |
| def test_freeze_with_interface_mutable(self): |
| @torch.jit.interface |
| class ModuleInterface(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| class ImplementsInterface(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sum = torch.zeros((2, 2)) |
| |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| self.sum += inp.relu() |
| return self.sum |
| |
| class WrapperModule(torch.nn.Module): |
| impl: ModuleInterface |
| |
| def __init__(self): |
| super().__init__() |
| self.impl = ImplementsInterface() |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.impl.forward(x) |
| |
| m = torch.jit.script(WrapperModule()) |
| m.eval() |
| m_frozen = torch.jit.freeze(m) |
| |
| x = torch.rand((2, 2)) |
| |
| m_frozen(x) |
| self.assertEqual(m_frozen.impl.sum, x.relu()) |
| |
| def test_freeze_with_swapping_interfaces(self): |
| @torch.jit.interface |
| class ModuleInterface(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| class Implementation1(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| return inp.relu() |
| |
| class Implementation2(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| return inp.sin() |
| |
| class WrapperModule(torch.nn.Module): |
| impl: ModuleInterface |
| |
| def __init__(self): |
| super().__init__() |
| self.option1 = Implementation1() |
| self.option2 = Implementation2() |
| self.impl = self.option1 |
| self.idx = 0 |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| self.idx += 1 |
| if self.idx % 2 == 1: |
| self.impl = self.option1 |
| else: |
| self.impl = self.option2 |
| return self.impl(x) |
| |
| m = torch.jit.script(WrapperModule()) |
| m.eval() |
| with self.assertRaisesRegex(RuntimeError, "Freezing does not support SetAttr on an interface type"): |
| m_frozen = torch.jit.freeze(m) |
| |
| def test_freeze_recursive_interfaces(self): |
| @torch.jit.interface |
| class InnerInterface(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| @torch.jit.interface |
| class OuterInterface(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| class InnerImpl(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.x = torch.ones((2, 2)) |
| |
| def forward(self, inp): |
| return inp.cos() * self.x |
| |
| class OuterImpl(torch.nn.Module): |
| inner_impl: InnerInterface |
| |
| def __init__(self): |
| super().__init__() |
| self.inner_impl = InnerImpl() |
| |
| def forward(self, inp): |
| return inp.relu() + self.inner_impl(inp.sin()) |
| |
| class WrapperModule(torch.nn.Module): |
| outer_impl: OuterInterface |
| |
| def __init__(self): |
| super().__init__() |
| self.outer_impl = OuterImpl() |
| |
| def forward(self, inp): |
| return self.outer_impl(inp) + inp |
| |
| m = WrapperModule() |
| x = torch.rand((2, 2)) |
| expected = m(x) |
| |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| m_s = torch.jit.freeze(m_s) |
| actual = m_s(x) |
| |
| self.assertEqual(expected, actual) |
| |
| def test_freeze_recursive_interfaces_with_reassignment(self): |
| @torch.jit.interface |
| class InnerInterface(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| @torch.jit.interface |
| class OuterInterface(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| class InnerImpl1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.x = torch.ones((2, 2)) |
| |
| def forward(self, inp): |
| return inp.cos() * self.x |
| |
| |
| class InnerImpl2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.x = torch.ones((2, 2)) * 2 |
| |
| def forward(self, inp): |
| return inp.sin() / self.x |
| |
| class OuterImpl(torch.nn.Module): |
| inner_impl: InnerInterface |
| |
| def __init__(self): |
| super().__init__() |
| self.inner_impl = InnerImpl1() |
| self.impl1 = InnerImpl1() |
| self.impl2 = InnerImpl1() |
| self.idx = 0 |
| |
| def forward(self, inp): |
| self.idx += 1 |
| if self.idx % 2 == 0: |
| self.inner_impl = self.impl1 |
| else: |
| self.inner_impl = self.impl2 |
| return inp.relu() + self.inner_impl(inp.sin()) |
| |
| class WrapperModule(torch.nn.Module): |
| outer_impl: OuterInterface |
| |
| def __init__(self): |
| super().__init__() |
| self.outer_impl = OuterImpl() |
| |
| def forward(self, inp): |
| return self.outer_impl(inp) + inp |
| |
| m = WrapperModule() |
| |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| with self.assertRaisesRegex(RuntimeError, "Freezing does not support SetAttr on an interface type"): |
| m_s = torch.jit.freeze(m_s) |
| |
| def test_freeze_interface_swapping_two_methods(self): |
| @torch.jit.interface |
| class MyInterface(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| class Impl1(torch.nn.Module): |
| def forward(self, inp): |
| return inp.cos() |
| |
| class Impl2(torch.nn.Module): |
| def forward(self, inp): |
| return inp.sin() |
| |
| class WrapperModule1(torch.nn.Module): |
| interface_impl: MyInterface |
| |
| def __init__(self): |
| super().__init__() |
| self.interface_impl = Impl1() |
| self.impl1 = Impl1() |
| self.impl2 = Impl2() |
| self.idx = 0 |
| |
| def forward(self, x): |
| return self.interface_impl(x) |
| |
| @torch.jit.export |
| def other_method(self, x): |
| self.idx += 1 |
| if self.idx % 2 == 0: |
| self.interface_impl = self.impl1 |
| else: |
| self.interface_impl = self.impl2 |
| return self.interface_impl(x) |
| |
| class WrapperModule2(torch.nn.Module): |
| interface_impl: MyInterface |
| |
| def __init__(self): |
| super().__init__() |
| self.interface_impl = Impl1() |
| self.impl1 = Impl1() |
| self.impl2 = Impl2() |
| self.idx = 0 |
| |
| def forward(self, x): |
| self.idx += 1 |
| if self.idx % 2 == 0: |
| self.interface_impl = self.impl1 |
| else: |
| self.interface_impl = self.impl2 |
| return self.interface_impl(x) |
| |
| @torch.jit.export |
| def other_method(self, x): |
| return self.interface_impl(x) |
| |
| m1 = torch.jit.script(WrapperModule1()) |
| m2 = torch.jit.script(WrapperModule2()) |
| |
| m1.eval() |
| m2.eval() |
| |
| with self.assertRaisesRegex(RuntimeError, "Freezing does not support SetAttr on an interface type"): |
| torch.jit.freeze(m1, preserved_attrs=["other_method"]) |
| |
| with self.assertRaisesRegex(RuntimeError, "Freezing does not support SetAttr on an interface type"): |
| torch.jit.freeze(m2, preserved_attrs=["other_method"]) |
| |
| def test_freeze_recursive_interfaces_same_name(self): |
| @torch.jit.interface |
| class InnerInterface(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| @torch.jit.interface |
| class OuterInterface(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| class InnerImpl(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.x = torch.ones((2, 2)) |
| |
| def forward(self, inp): |
| return inp.cos() * self.x |
| |
| class OuterImpl(torch.nn.Module): |
| impl: InnerInterface |
| |
| def __init__(self): |
| super().__init__() |
| self.impl = InnerImpl() |
| self.x = torch.ones((2, 2)) * 5 |
| |
| def forward(self, inp): |
| return self.other_method(inp) |
| |
| def other_method(self, inp): |
| return inp.relu() + self.impl(inp.sin()) + self.x |
| |
| class WrapperModule(torch.nn.Module): |
| impl: OuterInterface |
| |
| def __init__(self): |
| super().__init__() |
| self.impl = OuterImpl() |
| |
| def forward(self, inp): |
| return self.impl(inp) + inp |
| |
| m = WrapperModule() |
| x = torch.rand((2, 2)) |
| expected = m(x) |
| |
| m_s = torch.jit.script(m) |
| m_s.eval() |
| m_s = torch.jit.freeze(m_s) |
| actual = m_s(x) |
| |
| self.assertEqual(expected, actual) |
| |
| def test_freeze_non_interface_module_swap(self): |
| class InnerModule(torch.nn.Module): |
| def __init__(self, x): |
| super().__init__() |
| self.x = x |
| |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| return inp.relu() + self.x |
| |
| class WrapperModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.option1 = InnerModule(torch.rand((2, 2))) |
| self.option2 = InnerModule(torch.rand((2, 2))) |
| self.impl = self.option1 |
| self.idx = 0 |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| self.idx += 1 |
| if self.idx % 2 == 1: |
| self.impl = self.option1 |
| else: |
| self.impl = self.option2 |
| return self.impl(x) |
| |
| unfrozen = WrapperModule() |
| m = torch.jit.script(unfrozen) |
| m.eval() |
| m_frozen = torch.jit.freeze(m) |
| |
| x = torch.rand((2, 2)) |
| expected = unfrozen(x) |
| actual = m_frozen(x) |
| self.assertEqual(expected, actual) |
| |
| @unittest.expectedFailure |
| def test_freeze_interface_within_object(self): |
| # I don't think there's any way to create a plain python object that |
| # contains a torch.nn.Module inside it, but just in case... I'm not |
| # sure freezing would handle this case correctly, so marking as xfail |
| # so that if this ever _does_ start working someone will need to |
| # investigate to make sure this is handled correctly. |
| class MyIface(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| class MyImpl(torch.nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| return inp.sin() |
| |
| class MyObject: |
| impl: MyIface |
| |
| def run(self, x): |
| return self.impl(x) |
| |
| class WrapperModule(torch.nn.Module): |
| impl: MyObject |
| |
| def __init__(self): |
| super().__init__() |
| self.impl = MyObject() |
| self.impl.impl = MyImpl() |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.impl(x) |
| |
| unfrozen = WrapperModule() |
| m = torch.jit.script(unfrozen) |
| m.eval() |
| m_frozen = torch.jit.freeze(m) |
| |
| x = torch.rand((2, 2)) |
| expected = unfrozen(x) |
| actual = m_frozen(x) |
        self.assertEqual(expected, actual)
| |
| def test_freeze_non_module_class_getattr(self): |
| class BoxCoder: |
| def __init__(self, bbox_xform_clip): |
| # type: (float) -> None |
| self.bbox_xform_clip = bbox_xform_clip |
| |
| def decode(self, input): |
| return input * self.bbox_xform_clip |
| |
| class MyModule(torch.nn.Module): |
| __annotations__ = { |
| 'box_coder': BoxCoder, |
| } |
| |
| def __init__(self): |
| super().__init__() |
| self.box_coder = BoxCoder(50.) |
| |
| def forward(self, input): |
| return self.box_coder.decode(input) |
| |
| model = MyModule() |
| model.eval() |
| script_model = torch.jit.freeze(torch.jit.script(model)) |
| inp = torch.randn([4, 4]) |
| output_eager = model(inp) |
| self.assertEqual(model(inp), script_model(inp)) |
| FileCheck().check_not("GetAttr").run(script_model.graph) |
| |
| def test_freeze_module_with_tupleoutput_submodule(self): |
| class SubModule(nn.Module): |
| def forward(self, x): |
| return (x + 1, x + 2) |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sub = SubModule() |
| |
| def forward(self, x): |
| y1, y2 = self.sub(x) |
| return y1 + y2 |
| |
| m = torch.jit.script(TestModule()) |
| m = m.eval() |
| mf = torch.jit.freeze(m) |
| inp = torch.randn(2, 2) |
| expected = m.forward(inp) |
| output = mf.forward(inp) |
        # Check that prim::TupleConstruct and prim::TupleUnpack
        # don't exist in the frozen graph
| FileCheck().check_not("prim::TupleConstruct").run(mf.graph) |
| FileCheck().check_not("prim::TupleUnpack").run(mf.graph) |
| self.assertEqual(output, expected) |
| |
| def test_freeze_module_with_call_method(self): |
| class Mod(nn.Module): |
| def __init__(self, val): |
| super().__init__() |
| self.param = nn.Parameter(val) |
| |
| def forward(self, x): |
| # this method will change during freezing |
| return x + self.param |
| |
| @torch.jit.export |
| def make_prediction(self, x): |
| y = x + x |
| return self.forward(y) |
| |
| param = torch.rand([2, 2]) |
| x = torch.rand([2, 2]) |
| |
| unscripted_mod = Mod(param) |
| mod = torch.jit.script(unscripted_mod) |
| mod.eval() |
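        # Preserve the exported make_prediction method so it remains callable on
        # the frozen module.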
| mod = torch.jit.freeze(mod, preserved_attrs=["make_prediction"]) |
| |
| self.assertEqual( |
| mod.forward(x), unscripted_mod.forward(x), atol=1e-5, rtol=1e-5 |
| ) |
| |
| class TestFrozenOptimizations(JitTestCase): |
| def setUp(self): |
| super().setUp() |
| self.default_dtype = torch.get_default_dtype() |
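        # Run these tests in double precision by default; tearDown restores the
        # saved default dtype.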
| torch.set_default_dtype(torch.double) |
| |
| def tearDown(self): |
| super().tearDown() |
| torch.set_default_dtype(self.default_dtype) |
| |
| def test_conv_bn_folding(self): |
| conv_bias = [True, False] |
| module_pairs = [(nn.Conv1d, nn.BatchNorm1d), (nn.Conv2d, nn.BatchNorm2d), (nn.Conv3d, nn.BatchNorm3d)] |
| use_tracing = [True, False] |
| bn_running_stats = [True, False] |
| |
| for use_bias, modules, tracing, track_stats in product(conv_bias, module_pairs, use_tracing, bn_running_stats): |
| class ConvBN(torch.nn.Module): |
| def __init__(self, in_channels, out_channels, **kwargs): |
| super().__init__() |
| self.conv = modules[0](in_channels, out_channels, bias=use_bias, **kwargs) |
| self.bn = modules[1](out_channels, eps=0.001, track_running_stats=track_stats) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return self.bn(x) |
| |
| mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval() |
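            # Build an input of shape (N, C, *spatial): one spatial dim for
            # Conv1d, two for Conv2d, three for Conv3d.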
| inps = [4, 3, 4] |
| if modules[0] == nn.Conv2d: |
| inps.append(inps[-1]) |
| if modules[0] == nn.Conv3d: |
| inps.append(inps[-1]) |
| inps.append(inps[-1]) |
| |
| inp = torch.rand(inps) |
| |
| if tracing: |
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
| else: |
| scripted_mod = torch.jit.script(mod_eager) |
| |
| self.run_pass("inline", scripted_mod.graph) |
| self.run_pass("peephole", scripted_mod.graph) |
| self.run_pass("constant_propagation", scripted_mod.graph) |
| |
| FileCheck().check("conv").check("batch").run(scripted_mod.graph) |
| # successfully no-ops with non-const inputs |
| self.run_pass("fold_frozen_conv_bn", scripted_mod.graph) |
| FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph) |
| |
| scripted_mod = torch.jit.freeze(scripted_mod) |
| self.run_pass("fold_frozen_conv_bn", scripted_mod.graph) |
| if track_stats: |
| FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph) |
| else: |
| FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph) |
| |
| self.assertEqual(mod_eager(inp), scripted_mod(inp)) |
| self.assertEqual(mod_eager(inp), scripted_mod(inp)) |
| |
| def test_conv_bn_folding_not_forward(self): |
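"""
Conv-BN folding should also apply to a preserved exported method
(make_prediction); preserving a non-method attribute ("amt") must not
interfere with the optimization.
"""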
| class ConvBN(torch.nn.Module): |
| def __init__(self, in_channels, out_channels, **kwargs): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=True, **kwargs) |
| self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) |
| self.amt = 3.2 |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return self.bn(x) |
| |
| @torch.jit.export |
| def make_prediction(self, x): |
| return self.forward(x) + self.amt |
| |
| mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval() |
| scripted_mod = torch.jit.script(mod_eager) |
| torch._C._jit_pass_inline(scripted_mod.make_prediction.graph) |
| FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.make_prediction.graph) |
| |
| # _jit_pass_optimize_frozen_graph should not be called on non-method attributes (e.g. "amt") |
| scripted_mod = torch.jit.freeze(scripted_mod, preserved_attrs=["make_prediction", "amt"]) |
| FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.make_prediction.graph) |
| |
# During freezing this creates tensor constants that are attached to the frozen graph,
# which is then kept alive by the compilation unit (which causes a leak)
| @skipCUDAMemoryLeakCheckIf(True) |
| @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") |
| def test_conv_bn_folding_autocast_scenario_cuda(self): |
# CUDA conv requires that all of its input tensors have the same dtype,
# which can cause issues if folding produces inputs of different dtypes.
| |
| class ConvBN(torch.nn.Module): |
| def __init__(self, in_channels, out_channels, **kwargs): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, dtype=torch.half, **kwargs) |
| self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float) |
| |
| def forward(self, x): |
| return self.bn(self.conv(x)) |
| |
| mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).cuda().eval() |
| scripted_mod = torch.jit.script(mod_eager) |
| scripted_mod = torch.jit.freeze(scripted_mod) |
| FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph) |
| conv_node = scripted_mod.graph.findNode("aten::conv2d", True) |
| self.assertTrue(conv_node is not None) |
| bias_input = conv_node.namedInput("bias") |
| self.assertTrue(bias_input is not None) |
| self.assertTrue(bias_input.type().dtype() == torch.half) |
| |
| x = torch.rand((3, 3, 32, 32), dtype=torch.half).cuda() |
| |
| self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) |
| self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) |
| |
| def test_conv_add_folding(self): |
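"""
Constant elementwise add/sub/mul/div following a conv should be folded
into the conv's weight/bias after freezing; folding is skipped when the
constant broadcasts over a non-channel dimension.
"""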
| |
| @torch.no_grad() |
| def test_conv_fusion(use_bias, module, tracing, op, scalar, add_tensor, expect_success): |
| |
| class ConvOp(torch.nn.Module): |
| __constants__ = ['use_scalar'] |
| |
| def __init__(self, in_channels, out_channels, tensor=None, **kwargs): |
| super().__init__() |
| self.conv = module(in_channels, out_channels, bias=use_bias, **kwargs) |
| self.conv2 = module(in_channels, out_channels, bias=use_bias, **kwargs) |
| self.use_scalar = scalar |
| tensor_size = [1 for _ in range(self.conv.weight.ndim)] |
| tensor_size[1] = self.conv.weight.size(0) |
| self.tensor = add_tensor if add_tensor is not None else torch.rand(tensor_size) |
| self.op = op |
| |
| def forward(self, x): |
| x = self.conv(x) |
| if self.use_scalar: |
| return self.op(x, 2.) |
| else: |
| return self.op(x, self.tensor) |
| |
| mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval() |
| |
| inps = [4, 3, 4] |
| if module == nn.Conv2d: |
| inps.append(inps[-1]) |
| if module == nn.Conv3d: |
| inps.append(inps[-1]) |
| inps.append(inps[-1]) |
| |
| inp = torch.rand(inps) |
| |
| if tracing: |
| scripted_mod = torch.jit.trace(mod_eager, (inp,)) |
| else: |
| scripted_mod = torch.jit.script(mod_eager) |
| |
| self.run_pass("inline", scripted_mod.graph) |
| op_str = "aten::" + op.__name__ |
| |
| FileCheck().check("conv").check(op_str).run(scripted_mod.graph) |
# successfully no-ops with non-const inputs
| self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph) |
| self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph) |
| FileCheck().check("conv").check(op_str).run(scripted_mod.graph) |
| scripted_mod = torch.jit.freeze(scripted_mod) |
| self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph) |
| self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph) |
| |
| if expect_success: |
| FileCheck().check("conv").check_not(op_str).run(scripted_mod.graph) |
| else: |
| FileCheck().check("conv").check(op_str).run(scripted_mod.graph) |
| |
| self.assertEqual(mod_eager(inp), scripted_mod(inp)) |
| self.assertEqual(mod_eager(inp), scripted_mod(inp)) |
| |
| conv_bias = [True, False] |
| modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d] |
| use_tracing = [False, True] |
| use_scalar = [False, True] |
| ops = [torch.add, torch.sub, torch.mul, torch.div] |
| |
| for use_bias, module, tracing, pytorch_op, scalar in product(conv_bias, modules, use_tracing, ops, use_scalar): |
| test_conv_fusion(use_bias, module, tracing, pytorch_op, scalar, add_tensor=None, expect_success=True) |
| |
| for use_bias, pytorch_op in product(conv_bias, ops): |
# broadcast tensor varies along a spatial dim, so it cannot be folded
| test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False, |
| add_tensor=torch.rand(32, 1, 32), expect_success=False) |
| |
# broadcasting op that can still be folded
| test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False, add_tensor=torch.rand(1, 1), expect_success=True) |
| |
# op with a tensor of a different dtype (still folded)
| test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False, |
| add_tensor=torch.tensor([2]).to(torch.int), expect_success=True) |
| |
| def test_conv_mul_add_bn(self): |
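"""
A conv -> mul -> add -> batch_norm chain with constant operands should
fold entirely into the conv after tracing and freezing.
"""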
| class Conv_Mul_Add_Bn(nn.Module): |
| |
| def __init__(self, in_channels, out_channels, **kwargs): |
| super().__init__() |
| self.conv = nn.Conv2d(in_channels, out_channels, **kwargs) |
| self.bn = nn.BatchNorm2d(out_channels, eps=0.001) |
| self.tensor1 = torch.tensor(2.2) |
| self.tensor2 = torch.tensor(2) |
| |
| def forward(self, x): |
| return self.bn(torch.add(torch.mul(self.conv(x), self.tensor1), self.tensor2)) |
| |
| input = torch.randn(8, 3, 64, 64) |
| model = Conv_Mul_Add_Bn(3, 32, kernel_size=3, stride=1).eval() |
| |
| with torch.no_grad(): |
| result = model(input) |
| traced_model = torch.jit.trace(model, input).eval() |
| traced_model = torch.jit.freeze(traced_model) |
| tresult = traced_model(input) |
| self.assertEqual(result, tresult) |
| FileCheck().check("conv").check_not("aten::batch_norm").run(traced_model.graph) |
| FileCheck().check("conv").check_not("aten::add").run(traced_model.graph) |
| |
| def test_linear_bn_folding(self): |
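"""
Folding of BatchNorm1d/2d/3d into a preceding Linear after freezing.
As with conv, the fold only happens when running stats are tracked.
"""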
| module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)] |
| use_tracing = [True, False] |
| bn_running_stats = [True, False] |
| |
| for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats): |
| class LinearBN(torch.nn.Module): |
| def __init__(self, in_features, out_features): |
| super().__init__() |
| self.linear = modules[0](in_features, out_features) |
| self.bn = modules[1](out_features, eps=0.001, track_running_stats=track_stats) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| return self.bn(x) |
| |
| mod_eager = LinearBN(32, 32).eval() |
| |
| inps = [3, 32] |
| if modules[1] == nn.BatchNorm2d: |
| inps.append(inps[-1]) |
| inps.append(inps[-1]) |
| if modules[1] == nn.BatchNorm3d: |
| inps.append(inps[-1]) |
| inps.append(inps[-1]) |
| inps.append(inps[-1]) |
| |
| inp = torch.rand(inps) |
| |
| if tracing: |
scripted_mod = torch.jit.trace(mod_eager, (inp,))
| else: |
| scripted_mod = torch.jit.script(mod_eager) |
| |
| self.run_pass("inline", scripted_mod.graph) |
| self.run_pass("peephole", scripted_mod.graph) |
| self.run_pass("constant_propagation", scripted_mod.graph) |
| |
| FileCheck().check("linear").check("batch").run(scripted_mod.graph) |
| # successfully no-ops with non-const inputs |
| self.run_pass("fold_frozen_linear_bn", scripted_mod.graph) |
| FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph) |
| |
| scripted_mod = torch.jit.freeze(scripted_mod) |
| self.run_pass("fold_frozen_linear_bn", scripted_mod.graph) |
| if track_stats: |
| FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph) |
| else: |
| FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph) |
| |
| self.assertEqual(mod_eager(inp), scripted_mod(inp)) |
| self.assertEqual(mod_eager(inp), scripted_mod(inp)) |
| |
| @skipCUDAMemoryLeakCheckIf(True) |
| @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") |
| def test_linear_bn_folding_autocast_scenario_cuda(self): |
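"""
Half-precision Linear followed by a float BatchNorm: after folding, the
folded weight and bias should keep the linear's half dtype.
"""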
| module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)] |
| use_tracing = [True, False] |
| bn_running_stats = [True, False] |
| |
| for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats): |
| class LinearBN(torch.nn.Module): |
| def __init__(self, in_features, out_features): |
| super().__init__() |
| self.linear = modules[0](in_features, out_features, bias=False, dtype=torch.half) |
| self.bn = modules[1](out_features, eps=0.001, dtype=torch.float) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| return self.bn(x) |
| |
| mod_eager = LinearBN(32, 32).cuda().eval() |
| |
| inps = [3, 32] |
| if modules[1] == nn.BatchNorm2d: |
| inps.append(inps[-1]) |
| inps.append(inps[-1]) |
| if modules[1] == nn.BatchNorm3d: |
| inps.append(inps[-1]) |
| inps.append(inps[-1]) |
| inps.append(inps[-1]) |
| |
| x = torch.rand(inps, dtype=torch.half).cuda() |
| |
| if tracing: |
scripted_mod = torch.jit.trace(mod_eager, (x,))
| else: |
| scripted_mod = torch.jit.script(mod_eager) |
| scripted_mod = torch.jit.freeze(scripted_mod) |
| FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph) |
| lin_node = scripted_mod.graph.findNode("aten::linear", True) |
| self.assertTrue(lin_node is not None) |
| weight_input = lin_node.namedInput("weight") |
| bias_input = lin_node.namedInput("bias") |
| self.assertTrue(bias_input is not None) |
| self.assertTrue(weight_input.type().dtype() == torch.half) |
| self.assertTrue(bias_input.type().dtype() == torch.half) |
| |
| self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) |
| self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) |
| |
| @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") |
| def test_linear_concat(self): |
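"""
Two linears that share the same input should be concatenated into a
single linear by concat_frozen_linear; the pass only fires on CUDA.
"""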
| out_dimms = [[5, 10], [1, 5]] |
| |
| for w1_dim, w2_dim in out_dimms: |
| class ModMultLinear(nn.Module): |
| def __init__(self, w1_dim, w2_dim): |
| super().__init__() |
| self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) |
| self.b1 = nn.Parameter(torch.rand([w1_dim])) |
| self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) |
| self.b2 = nn.Parameter(torch.rand([w2_dim])) |
| |
| def forward(self, in_tensor1): |
| res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) |
| res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2) |
| return res1, res2 |
| |
| mod_eager = ModMultLinear(w1_dim, w2_dim).eval() |
| |
| test_val1 = torch.rand([50, 5]) |
| self.check_linear_optimizations(mod_eager, 2, 1, (test_val1, )) |
| |
| @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") |
| def test_linear_concat_complex(self): |
| """ |
| Testing that the interleaving of multiple optimizations does not |
| cause errors, and gets optimized as expected |
| """ |
| class ModMultLinear(nn.Module): |
| def __init__(self): |
| super().__init__() |
| w1_dim = 5 |
| w2_dim = 10 |
| self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) |
| self.b1 = nn.Parameter(torch.rand([w1_dim])) |
| self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) |
| self.b2 = nn.Parameter(torch.rand([w2_dim])) |
| |
| def forward(self, in_tensor1): |
| res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) |
| res3 = torch._C._nn.linear(res1, self.w2, self.b2) |
| res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2) |
| res4 = torch._C._nn.linear(res1, self.w1, self.b1) |
| return res2, res3, res4 |
| |
| mod_eager = ModMultLinear().eval() |
| test_val1 = torch.rand([50, 5]) |
| self.check_linear_optimizations(mod_eager, 4, 2, (test_val1, )) |
| |
| @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") |
| def test_linear_concat_different_input(self): |
| """ |
| There should be no change to the graph due to the optimization pass |
| due to the two input tensors being different |
| """ |
| |
| # Freezing requires that the graph be a module |
| class ModMultLinear(nn.Module): |
| def __init__(self, w1_dim, w2_dim): |
| super().__init__() |
| self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) |
| self.b1 = nn.Parameter(torch.rand([w1_dim])) |
| self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) |
| self.b2 = nn.Parameter(torch.rand([w2_dim])) |
| |
| def forward(self, in_tensor1, in_tensor2): |
| res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) |
| res2 = torch._C._nn.linear(in_tensor2, self.w2, self.b2) |
| return res1, res2 |
| |
| mod_eager = ModMultLinear(5, 5).eval() |
| test_val1 = torch.rand([50, 5]) |
| test_val2 = torch.rand([50, 5]) |
| self.check_linear_optimizations(mod_eager, 2, 2, (test_val1, test_val2)) |
| |
| @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") |
| def test_linear_multiple_blocks(self): |
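"""
The concat pass should handle graphs containing nested blocks (here, an
if branch) correctly, merging only the eligible linears.
"""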
| class ModMultLinear(nn.Module): |
| def __init__(self, w1_dim, w2_dim): |
| super().__init__() |
| self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) |
| self.b1 = nn.Parameter(torch.rand([w1_dim])) |
| self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) |
| self.b2 = nn.Parameter(torch.rand([w2_dim])) |
| |
| def forward(self, in_tensor1, in_tensor2, cond: bool): |
| res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) |
| if cond: |
| res3 = torch._C._nn.linear(in_tensor2, self.w2, self.b2) |
| res4 = torch._C._nn.linear(in_tensor1, self.w2, self.b1) |
| else: |
| raise AssertionError() |
| res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b1) |
| return res1, res2, res3, res4 |
| |
| mod_eager = ModMultLinear(5, 5).eval() |
| test_val1 = torch.rand([50, 5]) |
| test_val2 = torch.rand([50, 5]) |
| self.check_linear_optimizations(mod_eager, 4, 3, (test_val1, test_val2, True)) |
| |
| def check_linear_optimizations(self, eager_mod, orig_linears, new_linears, test_vals): |
| for is_cuda in [False, True]: |
| if is_cuda: |
| mod_to_device = eager_mod.cuda() |
| test_vals_to_device = [t.cuda() if isinstance(t, torch.Tensor) else t for t in test_vals] |
| else: |
| mod_to_device = eager_mod |
| test_vals_to_device = test_vals |
| |
| script_mod = torch.jit.script(mod_to_device) |
| op_graph = script_mod.graph |
| |
| FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(op_graph) |
# successfully no-ops with non-const inputs
| self.run_pass("concat_frozen_linear", op_graph) |
| FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(op_graph) |
| |
| script_mod = torch.jit.freeze(script_mod) |
| op_graph = script_mod.graph |
| self.run_pass("concat_frozen_linear", op_graph) |
| if is_cuda: |
| FileCheck().check_count("aten::linear", new_linears, exactly=True).run(op_graph) |
| else: |
| FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(op_graph) |
| |
| self.assertEqual(mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device)) |
| |
| |
| def test_optimize_freeze_module(self): |
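"""
freeze(..., optimize_numerics=False) should leave batch_norm in the
graph; torch.jit.run_frozen_optimizations then removes it, and freezing
with default arguments removes it directly.
"""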
| in_channels, out_channels = 3, 32 |
| conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True) |
| bn = torch.nn.BatchNorm2d(out_channels, eps=.001) |
| mod = torch.nn.Sequential(conv, bn) |
# set optimize_numerics to False here; by default, freezing runs run_frozen_optimizations
| frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False) |
| # inspect frozen mod |
| FileCheck().check("batch_norm").run(frozen_mod.graph) |
| torch.jit.run_frozen_optimizations(frozen_mod) |
| FileCheck().check_not("batch_norm").run(frozen_mod.graph) |
| |
# with default arguments, run_frozen_optimizations is run during freezing
| frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval())) |
| FileCheck().check_not("batch_norm").run(frozen_mod.graph) |
| |
| def test_freeze_remove_dropout(self): |
| class Net(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.dropout = nn.Dropout(0.5) |
| |
| def forward(self, x): |
| return self.dropout(x) |
| |
| mod = torch.jit.script(Net()) |
| # inspect mod |
| torch._C._jit_pass_inline(mod.graph) |
| FileCheck().check("aten::dropout").run(mod.graph) |
| frozen_mod = torch.jit.freeze(mod.eval()) |
| FileCheck().check_not("aten::dropout").run(frozen_mod.graph) |
| |
| input = torch.randn(2) |
| output_s = mod.forward(input) |
| output_f = frozen_mod.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| def test_freeze_remove_feature_dropout(self): |
| class Net(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.dropout = nn.Dropout2d(0.5) |
| |
| def forward(self, x): |
| return self.dropout(x) |
| |
| mod = torch.jit.script(Net().eval()) |
| # inspect mod |
| torch._C._jit_pass_inline(mod.graph) |
| FileCheck().check("aten::feature_dropout").run(mod.graph) |
| frozen_mod = torch.jit.freeze(mod) |
| FileCheck().check_not("aten::feature_dropout").run(frozen_mod.graph) |
| |
| input = torch.randn(2, 2, 1, 1) |
| output_s = mod.forward(input) |
| output_f = frozen_mod.forward(input) |
| self.assertEqual(output_s, output_f) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
def test_freeze_mkldnn(self):
| conv = torch.nn.Conv2d(3, 32, kernel_size=3, stride=2).eval().float() |
| convmkl = mkldnn_utils.to_mkldnn(conv) |
| out = torch.jit.freeze(torch.jit.script(convmkl.eval())) |
| inp = torch.rand([4, 3, 4, 4]).float() |
| self.assertEqual(out(inp.to_mkldnn()).to_dense(), conv(inp)) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| def test_conv_to_mkldnn(self): |
| with set_default_dtype(torch.float): |
| for module, trace in product([nn.Conv2d, nn.Conv3d], [False, True]): |
| mod = module(3, 32, kernel_size=3, stride=2).eval() |
| inps = [4, 3, 4] |
| if module == nn.Conv2d: |
| inps.append(inps[-1]) |
| if module == nn.Conv3d: |
| inps.append(inps[-1]) |
| inps.append(inps[-1]) |
| |
| inp = torch.rand(inps) |
if trace:
scripted_mod = torch.jit.trace(mod, (inp,))
else:
scripted_mod = torch.jit.script(mod)
| |
| self.run_pass("inline", scripted_mod.graph) |
| |
| FileCheck().check("conv").run(scripted_mod.graph) |
| # successfully no-ops with non-const inputs |
| self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) |
| FileCheck().check_not("to_mkldnn").run(scripted_mod.graph) |
| |
| scripted_mod = torch.jit.freeze(scripted_mod) |
| self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) |
| FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check("to_dense").run(scripted_mod.graph) |
| |
| self.assertEqual(mod(inp), scripted_mod(inp)) |
| self.assertEqual(mod(inp), scripted_mod(inp)) |
| |
| def test_linear_transpose(self): |
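"""
With a constant (frozen) weight, the transpose_frozen_linear pass should
remove the aten::linear call from the graph.
"""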
| class ModLinear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.bias = torch.nn.Parameter(torch.rand(30)) |
| self.weight = torch.nn.Parameter(torch.rand([30, 20])) |
| |
| def forward(self, x): |
| return torch._C._nn.linear(x, self.weight, self.bias) |
| |
| mod_eager = ModLinear().eval() |
| test_val = torch.rand([50, 20]) |
| self.check_linear_optimizations_2(mod_eager, 1, 0, "transpose_frozen_linear", (test_val,)) |
| |
| def test_linear_non_constant_weight(self): |
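"""
transpose_frozen_linear should be a no-op when the weight is a runtime
input rather than a frozen constant.
"""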
| class ModLinear(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.bias = torch.nn.Parameter(torch.rand(30)) |
| |
| def forward(self, x, weight): |
| return torch._C._nn.linear(x, weight, self.bias) |
| |
| mod_eager = ModLinear().eval() |
| test_val = torch.rand([50, 20]) |
| test_weight = torch.rand([30, 20]) |
| self.check_linear_optimizations_2(mod_eager, 1, 1, "transpose_frozen_linear", (test_val, test_weight)) |
| |
| def check_linear_optimizations_2(self, eager_mod, orig_linears, new_linears, opt_pass, test_vals): |
| # TODO: merge with check_linear_optimizations once both diffs land |
| mod_to_device = eager_mod |
| test_vals_to_device = test_vals |
| |
| script_mod = torch.jit.script(mod_to_device) |
| op_graph = script_mod.graph |
| |
| FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(op_graph) |
# successfully no-ops with non-const inputs
| self.run_pass(opt_pass, op_graph) |
| FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(op_graph) |
| |
| script_mod = torch.jit.freeze(script_mod) |
| op_graph = script_mod.graph |
| self.run_pass(opt_pass, op_graph) |
| FileCheck().check_count("aten::linear", new_linears, exactly=True).run(op_graph) |
| |
| self.assertEqual(mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device)) |
| |
| @staticmethod |
| def conv(): |
| # Generic composable conv for testing purposes |
| return nn.Conv2d(8, 8, 1) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| def test_collapse_adjacent_conversions(self): |
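"""
Two back-to-back MKLDNN convolutions should share a single
to_mkldnn/to_dense pair rather than converting around each op.
"""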
| |
| with set_default_dtype(torch.float): |
| mod = nn.Sequential(self.conv(), self.conv()).eval() |
| scripted_mod = torch.jit.script(mod) |
| scripted_mod = torch.jit.freeze(scripted_mod) |
| self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) |
| FileCheck().check("to_mkldnn") \ |
| .check("prim::mkldnn_convolution") \ |
| .check("prim::mkldnn_convolution") \ |
| .check("to_dense") \ |
| .run(scripted_mod.graph) |
| FileCheck().check_count("to_mkldnn", 1, exactly=True).run(scripted_mod.graph) |
| |
| inp = torch.rand([1, 8, 8, 8]) |
| self.assertEqual(scripted_mod(inp), mod(inp)) |
| self.assertEqual(scripted_mod(inp), mod(inp)) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| def test_mkldnn_fuser_broadcasting(self): |
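"""
A broadcasting add against an MKLDNN conv output should go through
prim::BroadcastMKLDNNTensors, since plain MKLDNN tensors do not
support broadcasting.
"""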
| class Add(nn.Module): |
| def __init__(self, tensor): |
| super().__init__() |
| self.tensor = tensor |
| |
| def forward(self, x): |
| return x + self.tensor |
| |
| with set_default_dtype(torch.float): |
| for add_inp in [8], [8, 8, 1]: |
| mod = nn.Sequential(self.conv(), Add(torch.rand(add_inp))).eval() |
| scripted_mod = torch.jit.script(mod) |
| scripted_mod = torch.jit.freeze(scripted_mod) |
| self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) |
| FileCheck().check("prim::BroadcastMKLDNNTensors").run(scripted_mod.graph) |
| inp = torch.rand([1, 8, 8, 8]) |
| self.assertEqual(scripted_mod(inp), mod(inp)) |
| self.assertEqual(scripted_mod(inp), mod(inp)) |
| |
# for good measure, check that broadcasting does not work without this op,
# so we can remove the op if broadcasting ever gets supported natively
| with self.assertRaisesRegex(RuntimeError, ""): |
| torch.rand([1, 8, 8, 8]).to_mkldnn() + torch.rand(add_inp).to_mkldnn() |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| def test_mkldnn_inplace_removal(self): |
| class AddMul(nn.Module): |
| def __init__(self, tensor): |
| super().__init__() |
| self.tensor = tensor |
| |
| def forward(self, x): |
| return x.add_(self.tensor).div_(self.tensor) - 4 |
| |
| with set_default_dtype(torch.float): |
| mod = nn.Sequential(self.conv(), AddMul(torch.rand([8]))).eval() |
| scripted_mod = torch.jit.script(mod) |
| scripted_mod = torch.jit.freeze(scripted_mod) |
| self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) |
# the in-place add gets un-inplaced and then re-inplaced
| FileCheck().check("aten::to_mkldnn").check("aten::add_").check("aten::div_").run(scripted_mod.graph) |
| inp = torch.rand([1, 8, 8, 8]) |
| self.assertEqual(scripted_mod(inp), mod(inp)) |
| self.assertEqual(scripted_mod(inp), mod(inp)) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| @skipIfNoTorchVision |
| def test_maxpool_mkldnn(self): |
| with set_default_dtype(torch.float): |
| model = torchvision.models.resnet18() |
| sub_model = torch.nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool) |
| mod = torch.jit.freeze(torch.jit.script(sub_model.eval())) |
N, C, H, W = 10, 3, 224, 224
| inp = torch.randn(N, C, H, W) |
| self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) |
| FileCheck().check("max_pool").check("to_dense").run(mod.graph) |
| FileCheck().check_count("to_dense", 1, exactly=True).run(mod.graph) |
| self.assertEqual(mod(inp), sub_model(inp)) |
| |
| @unittest.skipIf(torch.backends.mkldnn.is_available(), "Testing no mkldnn") |
| def test_conv_to_mkldnn_no_mkldnn(self): |
| # test no error when mkldnn not available |
| with set_default_dtype(torch.float): |
| mod = torch.jit.script(nn.Conv2d(3, 32, kernel_size=3, stride=2).eval()) |
| frozen = torch.jit.freeze(mod) |
| self.run_pass("convert_frozen_ops_to_mkldnn", frozen.graph) |
| inp = torch.rand([4, 3, 4, 4]) |
| self.assertEqual(frozen(inp), mod(inp)) |
| |
| @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN") |
| def test_freeze_conv_relu_fusion(self): |
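"""
optimize_for_inference should fuse conv + relu (optionally with an added
residual) into cudnn_convolution_relu / cudnn_convolution_add_relu, or
the miopen equivalents on ROCm.
"""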
| with set_default_dtype(torch.float): |
| conv_bias = [True, False] |
| conv_ops = [nn.Conv2d, nn.Conv3d] |
| add_z = [True, False] |
| use_tracing = [True, False] |
| for use_bias, conv, add_z, tracing in product(conv_bias, conv_ops, add_z, use_tracing): |
| class Net(nn.Module): |
| def __init__(self, in_channels, out_channels, **kwargs): |
| super().__init__() |
| self.conv = conv(in_channels, out_channels, bias=use_bias, **kwargs) |
| self.relu = nn.ReLU(inplace=True) |
| self.add_z = add_z |
| |
| def forward(self, x): |
| z = self.conv(x) |
| out = self.conv(x) |
| if self.add_z: |
| out += z |
| out = self.relu(out) |
| return out |
| |
| mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda() |
| |
| inps = [5, 3, 4, 4] |
| if conv == nn.Conv3d: |
| inps.append(inps[-1]) |
| inp = torch.rand(inps).cuda() |
| |
| if tracing: |
scripted_mod = torch.jit.trace(mod_eager, (inp,))
| else: |
| scripted_mod = torch.jit.script(mod_eager) |
| |
| frozen_mod = torch.jit.optimize_for_inference(scripted_mod) |
| if TEST_WITH_ROCM: |
| if add_z: |
| FileCheck().check("aten::miopen_convolution_add_relu").run(frozen_mod.graph) |
| else: |
| FileCheck().check("aten::miopen_convolution_relu").run(frozen_mod.graph) |
| else: |
| if add_z: |
| FileCheck().check("aten::cudnn_convolution_add_relu").run(frozen_mod.graph) |
| else: |
| FileCheck().check("aten::cudnn_convolution_relu").run(frozen_mod.graph) |
| |
| self.assertEqual(mod_eager(inp), frozen_mod(inp)) |
| |
| @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN") |
| def test_freeze_conv_relu_fusion_not_forward(self): |
| with set_default_dtype(torch.float): |
| class Net(nn.Module): |
| def __init__(self, in_channels, out_channels, **kwargs): |
| super().__init__() |
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
| self.relu = nn.ReLU(inplace=True) |
| |
| def forward(self, x): |
| z = self.conv(x) |
| out = self.conv(x) |
| out = self.relu(out) |
| return out |
| |
| @torch.jit.export |
| def make_prediction(self, x): |
| return self.forward(x) |
| |
| mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda() |
| |
| inps = [5, 3, 4, 4] |
| inp = torch.rand(inps).cuda() |
| |
| scripted_mod = torch.jit.script(mod_eager) |
| |
| frozen_mod = torch.jit.freeze(scripted_mod, preserved_attrs=['make_prediction']) |
| optimized_mod = torch.jit.optimize_for_inference(frozen_mod, other_methods=['make_prediction']) |
| if TEST_WITH_ROCM: |
| FileCheck().check("aten::miopen_convolution_relu").run(optimized_mod.make_prediction.graph) |
| else: |
| FileCheck().check("aten::cudnn_convolution_relu").run(optimized_mod.make_prediction.graph) |
| |
| self.assertEqual(mod_eager.make_prediction(inp), optimized_mod.make_prediction(inp)) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| def test_numel_less_than_size_with_padding(self): |
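"""
Tiny padded input whose numel is smaller than the padded size; the
optimize_for_inference output must still match eager.
"""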
| |
| with set_default_dtype(torch.float): |
| class MyModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d( |
| 1, 2, kernel_size=(2, 4), stride=2, padding=2, dilation=(2, 1), |
| ) |
| |
| def forward(self, i0): |
| x = self.conv1(i0) |
| o0 = torch.max(x, i0) |
| o1 = torch.clip(x, -1.5, 1.5) |
| return o0, o1 |
| |
| |
| i0 = torch.zeros((1, 1, 1, 2), dtype=torch.float32) |
| mod = MyModule() |
| out = mod(i0) |
| |
| exported = torch.jit.trace(mod, [i0]) |
| exported = torch.jit.optimize_for_inference(exported) |
| |
| eout = exported(i0) |
| self.assertTrue(all(torch.allclose(x, y) for x, y in zip(out, eout))) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| def test_incompatible_perf_formats(self): |
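"""
Adding a conv output to a max_pool output exercises incompatible MKLDNN
performance layouts; results must still match eager after the conversion
pass.
"""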
| with set_default_dtype(torch.float): |
| class Mod(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 64, 3, 2) |
| self.max_pool = torch.nn.MaxPool2d(111, 111) |
| |
| def forward(self, x): |
| a = self.conv(x) |
| b = self.max_pool(a) |
| return a + b |
| |
| model = Mod() |
| model.eval() |
| mod = torch.jit.freeze(torch.jit.script(model)) |
N, C, H, W = 10, 3, 224, 224
| inp = torch.randn(N, C, H, W) |
| self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) |
| self.assertEqual(model(inp), mod(inp)) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| def test_pool2d_batchnorm(self): |
| with set_default_dtype(torch.float): |
| |
| pooling_layers = [torch.nn.AdaptiveAvgPool2d(4), |
| # torch.nn.AdaptiveMaxPool2d(4), # return tuples |
| torch.nn.MaxPool2d(4), |
| torch.nn.AvgPool2d(4), |
| torch.nn.BatchNorm2d(64).eval()] |
| |
| for pl in pooling_layers: |
| sub_model = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 2, 2), torch.nn.ReLU(), pl, torch.nn.Hardswish()) |
| sub_model.eval() |
| mod = torch.jit.freeze(torch.jit.script(sub_model)) |
N, C, H, W = 10, 3, 224, 224
| inp = torch.randn(N, C, H, W) |
# these two passes are needed to remove
# a size check in BatchNorm2d
| removeExceptions(mod.graph) |
| self.run_pass('dce', mod.graph) |
| self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) |
| FileCheck().check("aten::to_dense").check_next("return").run(mod.graph) |
| self.assertEqual(sub_model(inp), mod(inp)) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| def test_pool3d_batchnorm(self): |
| with set_default_dtype(torch.float): |
| |
| pooling_layers = [torch.nn.MaxPool3d(4), |
| # torch.nn.AdaptiveAvgPool3d(4), # no ideep bindings |
| # torch.nn.AdaptiveMaxPool3d(4), # return tuples |
| torch.nn.AvgPool3d(4), |
| torch.nn.BatchNorm3d(64).eval()] |
| |
| for pl in pooling_layers: |
| sub_model = torch.nn.Sequential(torch.nn.Conv3d(3, 64, 2, 2), torch.nn.ReLU(), pl, torch.nn.Hardswish()) |
| sub_model.eval() |
| mod = torch.jit.freeze(torch.jit.script(sub_model)) |
| N, C, H, W, D = 10, 3, 64, 64, 64 |
| inp = torch.randn(N, C, D, H, W) |
# these two passes are needed to remove
# a size check in BatchNorm3d
| removeExceptions(mod.graph) |
| self.run_pass('dce', mod.graph) |
| self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) |
| FileCheck().check("aten::to_dense").check_next("return").run(mod.graph) |
| self.assertEqual(sub_model(inp), mod(inp)) |
| |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| @skipIfNoTorchVision |
| def test_conv_hardswish(self): |
| with set_default_dtype(torch.float): |
| class Clamp(torch.nn.Module): |
| def __init__(self, min_val, max_val, **kwargs): |
| super().__init__() |
| self.min_val = min_val |
| self.max_val = max_val |
| |
| def forward(self, x): |
| return torch.clamp(x, self.min_val, self.max_val) |
| |
N, C, H, W = 10, 3, 224, 224
| activations = [ |
| torch.nn.Hardswish(), |
| torch.nn.Hardsigmoid(), |
| torch.nn.ReLU6(), |
| torch.nn.Tanh(), |
| torch.nn.Hardtanh(0., 6.), |
| torch.nn.Hardtanh(1., 100.), |
| torch.nn.Hardtanh(-100., -1.), |
| torch.nn.GELU(), |
| Clamp(-100., -1.), |
| Clamp(1., 100.), |
| Clamp(0., 6.), |
| Clamp(-1., 0.), |
| ] |
| |
| model = torchvision.models.resnet18() |
| for activation in activations: |
| sub_model = torch.nn.Sequential(model.conv1, activation) |
| sub_model.eval() |
| mod = torch.jit.freeze(torch.jit.script(sub_model)) |
| inp = torch.randn(N, C, H, W) |
| self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) |
| FileCheck().check_count("aten::to_dense", 1, exactly=True).run(mod.graph) |
| self.assertEqual(sub_model(inp), mod(inp)) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| def test_hardswish_hardsigmoid(self): |
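"""
prim::MKLDNNHardSwish / prim::MKLDNNHardSigmoid (and their in-place
variants) on MKLDNN tensors should match the corresponding aten
functional ops, including for zero-element inputs.
"""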
| with set_default_dtype(torch.float): |
| op_map = { |
| 'prim::MKLDNNHardSwish' : F.hardswish, |
| 'prim::MKLDNNHardSigmoid' : F.hardsigmoid, |
| } |
| |
| input_sizes = ([0], [1], [3], [1, 3, 8, 8]) |
| for (mkldnn_opname, aten_op) in op_map.items(): |
| for size in input_sizes: |
| for inplace in (True, False): |
| inplace_str = "_" if inplace else "" |
| inplace_tgt = "%34" if inplace else "%35" |
| graph_str = f"""graph(%input.1 : Tensor): |
| %33 : None = prim::Constant() |
| %34 : Tensor = aten::to_mkldnn(%input.1, %33) |
| %35 : Tensor = {mkldnn_opname}{inplace_str}(%34) |
| return ({inplace_tgt}) |
| """ |
| g = torch._C.parse_ir(graph_str) |
| m = self.createFunctionFromGraph(g) |
| x = torch.rand(size) |
# `inplace=False` is intentional; otherwise we would modify the input,
# and we aren't testing the aten impls anyway
| self.assertEqual(aten_op(x, inplace=False), m(x).to_dense()) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| def test_scalar_mul(self): |
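"""
Scalar multiplications after an MKLDNN conv should be turned into
ScalarMul ops, in-place only where the input's lifetime allows.
"""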
| with set_default_dtype(torch.float): |
| class Mod(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mod = nn.Conv2d(8, 8, 1, padding=1) |
| |
| def forward(self, x): |
| a1 = self.mod(x) * 4 |
| return a1 * 4 + a1 * 5. |
| |
| mod = Mod().eval() |
| scripted = torch.jit.freeze(torch.jit.script(mod)) |
| optimized = torch.jit.optimize_for_inference(scripted) |
| inp = torch.rand([1, 8, 8, 8]) |
# a1 can't be inplaced for its first use, but can be for the second
| FileCheck().check("ScalarMul(").check("ScalarMul_").run(optimized.graph) |
| self.assertEqual(optimized(inp), mod(inp)) |
| |
| def test_remove_detach(self): |
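"""
Freezing should remove aten::detach when only the detached value's
contents are used.
"""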
| class Mod(nn.Module): |
| def forward(self, x): |
| y = x.detach() |
| return y * y |
| |
| mod = Mod().eval() |
| frozen_mod = torch.jit.freeze(torch.jit.script(mod)) |
| inp = torch.randn((2, 2)) |
| FileCheck().check_not("aten::detach").run(frozen_mod.graph) |
| self.assertEqual(frozen_mod(inp), mod(inp)) |
| |
| def test_remove_detach_not_applied(self): |
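"""
aten::detach must be kept when its identity is observed (x is y).
"""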
| class Mod(nn.Module): |
| def forward(self, x): |
| y = x.detach() |
| return x is y |
| |
| mod = Mod().eval() |
| frozen_mod = torch.jit.freeze(torch.jit.script(mod)) |
| inp = torch.randn((2, 2)) |
| FileCheck().check("aten::detach").run(frozen_mod.graph) |
| self.assertEqual(frozen_mod(inp), mod(inp)) |
| |
| @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") |
| class TestMKLDNNReinplacing(JitTestCase): |
| def setUp(self): |
| super().setUp() |
| self.default_dtype = torch.get_default_dtype() |
| torch.set_default_dtype(torch.float) |
| |
| def tearDown(self): |
| super().tearDown() |
| torch.set_default_dtype(self.default_dtype) |
| |
| def getConv(self): |
| return nn.Conv2d(3, 32, kernel_size=3, stride=2).eval() |
| |
| def getInput(self): |
| return torch.rand([4, 3, 4, 4]) |
| |
| def freezeAndConvert(self, mod): |
| mod = torch.jit.freeze(torch.jit.script(mod.eval())) |
| self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) |
| return mod |
| |
| def checkResults(self, mod1, mod2): |
| inp = self.getInput() |
| self.assertEqual(mod1(inp), mod2(inp)) |
| |
| def test_successful(self): |
# simple conv-hardswish-relu chain
| |
| mod_eager = nn.Sequential(self.getConv(), nn.Hardswish(), nn.ReLU()) |
| mod = self.freezeAndConvert(mod_eager) |
| FileCheck().check("mkldnn_convolution").check_next("prim::MKLDNNHardSwish_").check_next("aten::relu_").run(mod.graph) |
| self.checkResults(mod_eager, mod) |
| |
| def test_merge_liveness(self): |
| class Mod(nn.Module): |
| def __init__(self, tensor): |
| super().__init__() |
| self.tensor = tensor |
| |
| def forward(self, x): |
| # this mul can be inplaced since x is dead after this use |
| temporary = x * self.tensor |
# temporary's lifespan extends to the return node,
# so the add can not be inplaced
| return temporary + temporary, temporary |
| |
| mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) |
| mod = self.freezeAndConvert(mod_eager) |
| FileCheck().check("aten::mul_").check_not("aten::add_").run(mod.graph) |
| self.checkResults(mod_eager, mod) |
| |
| def test_always_alive_values(self): |
| class Mod(nn.Module): |
| def __init__(self, tensor): |
| super().__init__() |
| self.tensor = tensor |
| |
| def forward(self, x): |
# x can't be inplaced because it's a return value;
# check that the inplacing pass doesn't try to inplace
# self.tensor because it's always alive
| return x * self.tensor, x |
| |
| mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) |
| mod = self.freezeAndConvert(mod_eager) |
| FileCheck().check_not("aten::mul_").run(mod.graph) |
| self.checkResults(mod_eager, mod) |
| |
| conv = self.getConv() |
| |
| class Mod(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.tensor = torch.rand([4, 32, 1, 1]) |
| self.conv = conv |
| |
| def forward(self, x): |
# the shapes don't add up here; we're just testing a particular pattern
| conv_output = self.conv(x) |
| return conv_output, self.conv(torch.add(x, x)) |
| |
| mod = self.freezeAndConvert(Mod()) |
| # x is an input to the graph, and so it should not be inplaced |
| # in the torch.add(x, x) call |
| FileCheck().check_not("aten::add_").run(mod.graph) |
| |
| def test_switch_inputs_to_inplace(self): |
| class Mod(nn.Module): |
| def __init__(self, tensor): |
| super().__init__() |
| self.tensor = tensor |
| |
| def forward(self, x): |
# self.tensor cannot be inplaced, but x can,
# and because add is commutative we can swap the inputs to add_
| return self.tensor + x |
| |
| mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) |
| mod = self.freezeAndConvert(mod_eager) |
| FileCheck().check("aten::add_").run(mod.graph) |
| self.checkResults(mod_eager, mod) |