import torch
import torch.nn as nn
import torch.nn.functional as F
import unittest
from torch.testing._internal.jit_utils import JitTestCase
from torch._C import parse_ir
from torch.testing import FileCheck
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_utils import set_default_dtype
from torch.utils import mkldnn as mkldnn_utils
from torch.jit._recursive import wrap_cpp_module
from typing import Any
from itertools import product
import io
try:
import torchvision
HAS_TORCHVISION = True
except ImportError:
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
TEST_CUDA = torch.cuda.is_available()
TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None
TEST_CUDNN = False
if TEST_CUDA and not TEST_ROCM: # Skip ROCM
torch.ones(1).cuda() # initialize cuda context
TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=torch.device('cuda:0')))
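# Helper used by the MKLDNN pool/batchnorm tests below: strip prim::RaiseException
# nodes (e.g. BatchNorm's size checks) so dead-code elimination can remove the
# guarding branches before the MKLDNN conversion pass runs.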
def removeExceptions(graph):
for n in graph.findAllNodes('prim::RaiseException'):
n.destroy()
class TestFreezing(JitTestCase):
def test_freeze_module(self):
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.a = 1 # folded
self.b = 1.2 # folded
self.c = "hello" # folded
self.c2 = "hi\xA1" # not folded
self.d = [1, 1] # folded
self.e = [1.0, 1.1] # folded
self.f = ["hello", "world"] # folded
self.f2 = [(1, "Over \u0e55\u0e57 57")]
self.g = ([1, 2], 3.2, "4.4", torch.tensor([5.5], requires_grad=True)) # folded
self.h = {"layer" : [torch.tensor([7.7], requires_grad=True)]}
self.h2 = {"layer\xB1" : [torch.tensor([8.8], requires_grad=True)]}
self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded
self.ts = [torch.tensor([1.0, 2.0], requires_grad=True), torch.tensor([3.0, 4.0], requires_grad=True)] # folded
self.tt = [[torch.tensor([3.3, 2.3], requires_grad=True), None]]
def forward(self, x):
return str(self.a) + str(self.b) + self.c + self.c2 + str(self.d) + \
str(self.e) + str(self.f) + str(self.f2) + str(self.g) + \
str(self.h) + str(self.h2) + str(self.t) + str(self.ts) + str(self.tt)
m = torch.jit.script(M())
m.eval()
input = torch.randn(2, 2)
output_s = m.forward(input)
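        # torch._C._freeze_module is the internal binding behind torch.jit.freeze;
        # roughly, it clones the module, inlines forward, and folds attributes that
        # are never mutated into the graph as constants.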
m._c = torch._C._freeze_module(m._c)
buffer = io.BytesIO()
torch.jit.save(m._c, buffer)
buffer.seek(0)
m2 = torch.jit.load(buffer)
        # Check that the frozen module carries no attributes; every value,
        # including 'tt', has been folded into the graph:
        # module m {
        #   attributes {
        #   }
        #   ...
        # }
self.assertFalse(m2._c.hasattr('a'))
self.assertFalse(m2._c.hasattr('b'))
self.assertFalse(m2._c.hasattr('c'))
self.assertFalse(m2._c.hasattr('c2'))
self.assertFalse(m2._c.hasattr('d'))
self.assertFalse(m2._c.hasattr('e'))
self.assertFalse(m2._c.hasattr('f'))
self.assertFalse(m2._c.hasattr('f2'))
self.assertFalse(m2._c.hasattr('g'))
self.assertFalse(m2._c.hasattr('h'))
self.assertFalse(m2._c.hasattr('h2'))
self.assertFalse(m2._c.hasattr('t'))
self.assertFalse(m2._c.hasattr('ts'))
self.assertFalse(m2._c.hasattr('tt'))
output_f = m2.forward(input)
self.assertEqual(output_s, output_f)
def test_freeze_module_with_submodule(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = 11
self.b = 2
def forward(self, x):
return self.a + self.b
class SubModule2(nn.Module):
def __init__(self):
super(SubModule2, self).__init__()
self.a = 12
self.b = 2
def forward(self, x):
self.b = 30
return self.a + self.b
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub1 = SubModule()
self.sub2 = SubModule2()
self.a = 3
self.b = 4
def forward(self, x):
self.b = 20
return self.sub1(x) + self.a + self.b + self.sub2(x)
m = torch.jit.script(TestModule())
m.eval()
input = torch.randn(2, 2)
output_s = m.forward(input)
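        # torch.jit.freeze is the public wrapper around torch._C._freeze_module.
        # Attributes that are mutated in forward (self.b here, and b inside sub2)
        # must survive freezing; the assertions below verify exactly that.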
mf = torch.jit.freeze(m)
# Check if frozen module looks as below:
# module m {
# attributes {
# sub2 = ...
# b =
# }
# ...
# submodule {
# module m {
# attributes {
# sub2 = ...
# b =
# }
# ...
# }
# }
# }
mf = mf._c
self.assertFalse(mf.hasattr('sub1'))
self.assertFalse(mf.hasattr('a'))
self.assertTrue(mf.hasattr('b'))
self.assertTrue(mf.hasattr('sub2'))
self.assertTrue(mf.sub2.hasattr('b')) # verify b is preserved in sub2
self.assertFalse(mf.sub2.hasattr('a')) # verify a is removed in sub2
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)
def test_freeze_module_with_fork(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = torch.ones(20, 20)
self.b = torch.ones(20, 20)
def forward(self, x):
return self.a * self.b + x
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub = SubModule()
def forward(self, x):
fut = torch.jit._fork(self.sub.forward, x)
y_hat = self.sub(x)
y = torch.jit._wait(fut)
return y_hat + y
m = torch.jit.script(TestModule())
m.eval()
input = torch.randn(20, 20)
output_s = m.forward(input)
mf = torch._C._freeze_module(m._c)
# Check if frozen module looks as below:
# module m {
# attributes {
# }
# ...
# submodule {
# }
# }
self.assertFalse(mf.hasattr('a'))
self.assertFalse(mf.hasattr('b'))
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)
def test_freeze_module_with_nested_fork(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = torch.ones(20, 20)
self.b = torch.ones(20, 20)
def forward(self, x):
return self.a * self.b + x
class SubModule2(nn.Module):
def __init__(self):
super(SubModule2, self).__init__()
self.sub = SubModule()
self.c = torch.ones(20, 20)
def forward(self, x):
fut = torch.jit._fork(self.sub.forward, x)
y_hat = self.sub(x)
y = torch.jit._wait(fut)
return y_hat + y + self.c
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub = SubModule2()
self.d = 1
def forward(self, x):
fut = torch.jit._fork(self.sub.forward, x)
y_hat = self.sub(x)
y = torch.jit._wait(fut)
self.d = 2
return y_hat * y + self.d
m = torch.jit.script(TestModule())
m.eval()
input = torch.randn(20, 20)
output_s = m.forward(input)
mf = torch._C._freeze_module(m._c)
# Check if frozen module looks as below:
# module m {
# attributes {
# }
# ...
# submodule {
# }
# }
self.assertFalse(mf.hasattr('a'))
self.assertFalse(mf.hasattr('b'))
self.assertFalse(mf.hasattr('c'))
self.assertTrue(mf.hasattr('d'))
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)
def test_freeze_module_with_fork2(self):
@torch.jit.script
def foo(x):
return x * 2
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.a = torch.ones(20, 20)
self.b = torch.ones(20, 20)
def forward(self, x):
fut = torch.jit._fork(foo, self.a)
y_hat = foo(self.b)
y = torch.jit._wait(fut)
return y_hat + y
m = torch.jit.script(TestModule())
m.eval()
input = torch.randn(2, 2)
output_s = m.forward(input)
mf = torch._C._freeze_module(m._c)
# Check if frozen module looks as below:
# module m {
# attributes {
# self.a = ...
# self.b = ..
# }
# ...
# submodule {
# }
# }
        # TODO: Although there is no mutation, the alias analysis
        # conservatively assumes one because 'a' is passed to the fork
        # subgraph, so 'a' is preserved while 'b' is folded.
self.assertTrue(mf.hasattr('a'))
self.assertFalse(mf.hasattr('b'))
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)
def test_freeze_module_with_fork_calling_module_method(self):
@torch.jit.script
def foo(x, y):
return x * y
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.a = torch.ones(20, 20)
self.b = torch.ones(20, 20)
@torch.jit.export
def foo(self, x):
return x * self.a
@torch.jit.export
def bar(self, x):
return x * self.b
def forward(self, x):
fut = torch.jit._fork(self.foo, self.b)
y_hat = self.bar(self.a)
y = torch.jit._wait(fut)
return y_hat + y
m = torch.jit.script(TestModule())
m.eval()
input = torch.randn(2, 2)
output_s = m.forward(input)
mf = torch._C._freeze_module(m._c)
# Check if frozen module looks as below:
# module m {
# attributes {
# self.b = ..
# }
# ...
        # TODO: Although there is no mutation, the alias analysis
        # conservatively assumes one because attributes are passed to the
        # fork subgraph. 'b' is preserved.
self.assertFalse(mf.hasattr('a'))
self.assertTrue(mf.hasattr('b'))
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)
def test_freeze_module_with_sharedclasstype(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = torch.tensor([1.1])
self.b = torch.tensor([2.2])
def forward(self, x):
return self.a + self.b
@torch.jit.export
def modify_a(self, x):
self.a[0] += 10
                return self.b
@torch.jit.export
def modify_b(self, x):
self.b[0] += 20
return self.a
class SubModule2(nn.Module):
def __init__(self):
super(SubModule2, self).__init__()
self.sub = SubModule()
self.b = torch.tensor([3.3])
def forward(self, x):
y = self.sub.modify_b(x)
return y + self.b
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub1 = SubModule() # sub1 and sub2.sub shared same class type.
self.sub2 = SubModule2()
self.a = torch.tensor([4.4])
def forward(self, x):
z = self.sub1.modify_a(x)
return self.sub2(x) + z + self.a
m = torch.jit.script(TestModule())
m.eval()
input = torch.randn(2, 2)
output_s = m.forward(input)
mf = torch._C._freeze_module(m._c)
        # Check if frozen module looks as below:
# module mf {
# attributes {
# sub1 = ...
# sub2 = ...
# }
# ...
# submodules {
# module sub1 {
# attributes {
# a = ...
# b = ...
# }
# ...
# }
# module sub2 {
# attributes {
# sub = ...
# }
# ...
# submodule {
# module sub {
# attributes {
# a = ...
# b = ...
# }
# ...
# }
# }
# }
# }
# }
self.assertTrue(mf.hasattr('sub1'))
self.assertTrue(mf.sub1.hasattr('a'))
self.assertTrue(mf.sub1.hasattr('b'))
self.assertFalse(mf.hasattr('a'))
self.assertTrue(mf.hasattr('sub2'))
self.assertTrue(mf.sub2.hasattr('sub'))
self.assertFalse(mf.sub2.hasattr('b'))
self.assertTrue(mf.sub2.sub.hasattr('a'))
self.assertTrue(mf.sub2.sub.hasattr('b'))
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)
def test_freeze_module_with_nestedaliasing(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = torch.tensor([1.1])
self.b = torch.tensor([2.2])
def forward(self, x):
return self.a + self.b
@torch.jit.export
def modify_a(self, x):
self.a[0] = 10
                return self.b
@torch.jit.export
def modify_b(self, x):
self.b[0] = 20
return self.a
Sub = SubModule()
class SubModule2(nn.Module):
def __init__(self):
super(SubModule2, self).__init__()
self.sub = Sub # aliasing
def forward(self, x):
return self.sub.a
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub1 = Sub # aliasing
self.sub2 = SubModule2()
def forward(self, x):
z = self.sub1.modify_a(x)
return self.sub2(x) + z
m = torch.jit.script(TestModule())
m.eval()
mf = torch._C._freeze_module(m._c)
self.assertTrue(mf.hasattr('sub1'))
self.assertTrue(mf.sub1.hasattr('a'))
self.assertFalse(mf.sub1.hasattr('b'))
self.assertTrue(mf.hasattr('sub2'))
self.assertTrue(mf.sub2.hasattr('sub'))
        self.assertTrue(mf.sub2.sub.hasattr('a'))  # Freezing detects that self.sub2.sub.a and self.sub1.a are aliases
self.assertFalse(mf.sub2.sub.hasattr('b'))
input = torch.randn(2, 2)
output_s = m.forward(input)
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)
# FIXME: JIT is not honoring aliasing. 'Sub' module is copied. As a result
# Eager and Script modules produce different output.
def test_freeze_module_with_nestedaliasingscalar(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = 1.1
self.b = 2.2
def forward(self, x):
return self.a + self.b
@torch.jit.export
def modify_a(self, x):
self.a = 10.0
                return self.b
@torch.jit.export
def modify_b(self, x):
self.b = 20.0
return self.a
Sub = SubModule()
class SubModule2(nn.Module):
def __init__(self):
super(SubModule2, self).__init__()
self.sub = Sub # aliasing
def forward(self, x):
return self.sub.a
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub1 = Sub # aliasing
self.sub2 = SubModule2()
def forward(self, x):
z = self.sub1.modify_a(x)
return self.sub2(x) + z
m = TestModule()
ms = torch.jit.script(m)
ms.eval()
mf = torch._C._freeze_module(ms._c)
self.assertTrue(mf.hasattr('sub1'))
self.assertTrue(mf.sub1.hasattr('a'))
self.assertFalse(mf.sub1.hasattr('b'))
        # sub2 is fully folded because self.sub1 and self.sub2.sub are not aliases (scripting bug)
self.assertFalse(mf.hasattr('sub2'))
input = torch.randn(2, 2)
output = m.forward(input)
output_s = ms.forward(input)
output_f = mf.forward(input)
        # output and output_s should be equal, but differ because of the
        # scripting aliasing bug noted above.
self.assertNotEqual(output, output_s)
self.assertEqual(output_s, output_f)
def test_freeze_module_with_preserve_sub_module(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = torch.tensor([1.1])
self.b = 2.2
def forward(self, x):
return self.a
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
                self.sub1 = SubModule()
self.sub2 = SubModule()
def forward(self, x):
return self.sub2(x) + self.sub1(x)
m = TestModule()
ms = torch.jit.script(m)
ms.eval()
mf = torch._C._freeze_module(ms._c, ["sub1"])
# Test that 'sub1' is preserved entirely and 'sub2' is completely folded
self.assertTrue(mf.hasattr('sub1'))
self.assertTrue(mf.sub1.hasattr('a'))
self.assertTrue(mf.sub1.hasattr('b'))
self.assertFalse(mf.hasattr('sub2'))
input = torch.randn(2, 2)
output_s = ms.forward(input)
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)
def test_freeze_module_with_preserve_sub_module_and_mutation(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = torch.tensor([1.1])
self.b = 2.2
def forward(self, x):
self.a[0] = 3.3
return self.a
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
                self.sub1 = SubModule()
self.sub2 = SubModule()
def forward(self, x):
return self.sub2(x) + self.sub1(x)
m = TestModule()
ms = torch.jit.script(m)
ms.eval()
mf = torch._C._freeze_module(ms._c, ["sub1"])
        # Test that both 'sub1' and 'sub2' are preserved and that 'b' is kept
        # even though it is not used, to fulfill the user request to preserve 'sub1'.
self.assertTrue(mf.hasattr('sub1'))
self.assertTrue(mf.sub1.hasattr('a'))
self.assertTrue(mf.sub1.hasattr('b'))
self.assertTrue(mf.hasattr('sub2'))
self.assertTrue(mf.sub2.hasattr('a'))
self.assertTrue(mf.sub2.hasattr('b'))
input = torch.randn(2, 2)
output_s = ms.forward(input)
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)
def test_freeze_module_with_helperfunction(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = 11
self.b = 2
def forward(self, x):
return self.a + self.b
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub = SubModule()
self.a = 3
self.b = 4
def forward(self, x):
self.b = 20
return self._forward(x) + self.a + self.b
def _forward(self, x):
return self.sub(x)
m = torch.jit.script(TestModule())
m.eval()
input = torch.randn(2, 2)
mf = torch._C._freeze_module(m._c)
self.assertFalse(mf.hasattr('sub'))
self.assertFalse(mf.hasattr('a'))
self.assertTrue(mf.hasattr('b'))
        with self.assertRaisesRegex(AttributeError, r"TestModule \(.*\) does not have a field with name '_forward'"):
            mf._forward(input)
def test_freeze_module_with_inplace_mutable(self):
class FreezeMe(torch.jit.ScriptModule):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = [11, 22]
@torch.jit.script_method
def forward(self, x):
for i in range(3):
self.a.append(i)
return self.a
m = FreezeMe()
m.eval()
m_f = torch._C._freeze_module(m._c)
self.assertTrue(m_f.hasattr('a'))
m.forward(torch.tensor([3]))
out = m_f.forward(torch.tensor([5]))
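        # The preserved list attribute is shared between m and m_f, so the appends
        # from m.forward and m_f.forward both land in the same list.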
expected = [11, 22, 0, 1, 2, 0, 1, 2]
self.assertEqual(out, expected)
# Mutable attributes
def test_freeze_module_with_mutable_list(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = [1, 2]
def forward(self, x):
return self.a
m = FreezeMe()
m.eval()
m.a.append(3)
m_s = torch.jit.script(m)
v = m_s.a
v.append(4)
m_s.a = v
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
        # Mutating m_s.a after freezing does not affect m_f (m_f has its own copy).
v = m_s.a
v.append(5)
m_s.a = v
self.assertFalse(m_f.hasattr('a'))
out = m_f.forward(torch.tensor([5]))
expected = [1, 2, 3, 4]
self.assertEqual(out, expected)
def test_freeze_module_with_mutable_dict(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = {"layer" : "4"}
def forward(self, x):
return self.a
@torch.jit.export
def modify_a(self, x):
self.a["layer"] = self.a["layer"] + "1"
return self.a
m = FreezeMe()
m.eval()
m.a["layer2"] = "3"
m_s = torch.jit.script(m)
t = torch.tensor(5)
m_s.modify_a(t)
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
m.a["layer2"] += "2"
m_s.modify_a(t)
self.assertFalse(m_f.hasattr('a'))
out = m_f.forward(t)
expected = {"layer" : "411", "layer2" : "3"}
self.assertEqual(out, expected)
def test_freeze_module_with_mutable_tensor(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = torch.tensor([1., 2., 3.])
def forward(self, x):
return self.a
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.a[1] += 3.0
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
# Post-freezing tensor attribute mutations affect m_f.
# FIXME: deep copy all folded attributes so that m_f has full ownership.
m_s.a[0] += 5.0
self.assertFalse(m_f.hasattr('a'))
out = m_f.forward(torch.tensor([5]))
expected = [6., 5., 3.]
self.assertEqual(out, expected)
def test_freeze_module_with_tuple(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = (torch.tensor([1, 2, 3, 4, 5, 6]), "hi")
def forward(self, x):
if (x[0] == 2.0):
self.a[0][0] = 10
return self.a[0].sum()
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
inp = torch.tensor([2.0])
expected = m_s.forward(inp)
m_s.a[0][0] = 1
m_f = torch._C._freeze_module(m_s._c)
self.assertFalse(m_f.hasattr('a'))
out = m_f.forward(inp)
self.assertEqual(out, expected)
def test_freeze_module_with_tensor(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = torch.tensor([1, 2, 3, 4, 5, 6])
def forward(self, x):
x = self.a.view(2, 3)
x[0][0] += 10
return self.a.sum()
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
inp = torch.tensor([5])
expected = m_s.forward(inp)
m_f = torch._C._freeze_module(m_s._c)
self.assertTrue(m_f.hasattr('a'))
m_f.a[0] -= 10
out = m_f.forward(inp)
self.assertEqual(out, expected)
def test_freeze_module_with_list(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = [torch.tensor([1, 2, 3, 4, 5, 6])]
def forward(self, x):
self.a[0][1] += 10
return self.a[0].sum()
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
inp = torch.tensor([5])
expected = m_s.forward(inp)
m_s.a[0][1] -= 10
m_f = torch._C._freeze_module(m_s._c)
self.assertFalse(m_f.hasattr('a'))
out = m_f.forward(inp)
self.assertEqual(out, expected)
def test_freeze_module_with_aliased_tensor_attr(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = torch.tensor([1, 2, 3, 4, 5, 6])
self.b = self.a.view(2, 3)
def forward(self, x):
self.b[1] += 10
return self.a.sum()
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
self.assertTrue(m_f.hasattr('a'))
inp = torch.tensor([5])
out = m_f.forward(inp)
expected = torch.tensor(51) # 1+2+3+14+15+16
self.assertEqual(out, expected)
def test_freeze_module_with_aliased_tensor_attr2(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = torch.tensor([1, 2, 3, 4, 5, 6])
self.b = {"layer" : ([self.a.view(2, 3), torch.tensor([10])], 20)}
self.c = ([self.a.view(2, 3), torch.tensor([10])], 20)
self.d = (self.a.view(2, 3), 20)
def forward(self, x):
self.d[0][0] += 10
return self.a.sum()
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
inp = torch.tensor([5])
expected = m_s.forward(inp)
with self.assertRaisesRegex(RuntimeError, "module contains attributes values that overlaps"):
m_f = torch._C._freeze_module(m_s._c)
def test_freeze_module_with_aliased_tensor_attr3(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = torch.tensor([1, 2, 3, 4, 5, 6])
self.b = [self.a, torch.tensor([10])]
def forward(self, x):
self.a[1] += 10
return self.b[0].sum()
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
inp = torch.tensor([5])
expected = m_s.forward(inp)
m_f = torch._C._freeze_module(m_s._c)
self.assertTrue(m_f.hasattr('a'))
self.assertTrue(m_f.hasattr('b'))
out = m_f.forward(inp)
expected += 10 # account for self.a += 10.
self.assertEqual(out, expected)
def test_freeze_module_with_aliased_tensor_attr4(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = torch.tensor([1, 2, 3, 4, 5, 6])
self.b = [self.a, torch.tensor([10])]
def forward(self, x):
self.b[0][0] += 10
return self.a.sum()
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
inp = torch.tensor([5])
expected = m_s.forward(inp)
m_s.a[0] -= 10
with self.assertRaisesRegex(RuntimeError, "module contains attributes values that overlaps"):
m_f = torch._C._freeze_module(m_s._c)
def test_freeze_module_with_overlapping_attrs(self):
a = torch.tensor([1, 2, 3, 4, 5, 6])
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.b = [a.view(3, 2), torch.tensor([10])]
self.c = (20, a.view(2, 3))
def forward(self, x):
self.b[0][0] += 10
return self.c[1].sum()
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
inp = torch.tensor([5])
expected = m_s.forward(inp)
a[0] -= 10
with self.assertRaisesRegex(RuntimeError, "module contains attributes values that overlaps"):
m_f = torch._C._freeze_module(m_s._c)
def test_freeze_module_with_aliased_attr(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = [1, 2, 3, 4, 5, 6]
self.b = self.a
self.c = (self.a, 10)
def forward(self, x):
self.b[1] += 10
return str(self.a) + str(self.c)
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
# FIXME: It should be assertTrue. Currently scripting is making a copy for setting self.b (see #33034)
self.assertFalse(m_f.hasattr('a'))
self.assertFalse(m_f.hasattr('c'))
inp = torch.tensor([5])
out = m_f.forward(inp)
expected = m_s.forward(inp)
self.assertEqual(out, expected)
    # Check that attribute 'a' is preserved. Alias analysis detects that 'a' has output writers.
    # In this example, 'a' is not mutated. However, we do not track which sub-values
    # of a composite ivalue are mutated.
def test_freeze_module_with_aliased_attr2(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = [1, 2, 3, 4, 5, 6]
self.b = ([11], [10])
def forward(self, x):
v = self.a
self.b = (v, [12])
v2 = self.b[1]
v2.append(7)
return str(v) + str(v2)
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
self.assertTrue(m_f.hasattr('a'))
inp = torch.tensor([5])
out = m_f.forward(inp)
expected = m.forward(inp)
self.assertEqual(out, expected)
def test_freeze_module_with_aliased_attr3(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = [1, 2, 3, 4, 5, 6]
self.b = ([11], [10])
def forward(self, x):
v = self.a
v2 = (v, [12])
v3 = v2[0]
v3.append(7)
return str(self.a)
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
self.assertTrue(m_f.hasattr('a'))
inp = torch.tensor([5])
out = m_f.forward(inp)
expected = m.forward(inp)
self.assertEqual(out, expected)
def test_freeze_module_return_self(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.a = torch.tensor([1., 2., 3.])
def forward(self, x):
return self
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
with self.assertRaisesRegex(RuntimeError, "attempted to freeze a module that return itself"):
m_f = torch._C._freeze_module(m_s._c)
def test_freeze_module_return_sub_module(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
def forward(self, x):
return self.conv1
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
self.assertTrue(m_f.hasattr('conv1'))
def test_freeze_module_no_forward(self):
class FreezeMe(nn.Module):
def __init__(self):
super(FreezeMe, self).__init__()
self.lin = nn.Linear(10, 1)
@torch.jit.export
def foo(self, x):
return self.lin(x)
m = FreezeMe()
m_s = torch.jit.script(m)
m_s.eval()
m_f = torch._C._freeze_module(m_s._c, preservedAttrs=['foo'])
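        # With no forward method, freezing needs an explicitly preserved entry point;
        # preservedAttrs=['foo'] keeps the exported method callable on the frozen module.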
input = torch.ones(10)
self.assertEqual(m_s.foo(input), m_f.foo(input))
def test_freeze_module_in_training_mode(self):
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = self.conv2(x)
x = nn.functional.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = nn.functional.log_softmax(x, dim=1)
return output
model = torch.jit.script(Net())
model.train()
mTrain_freezed = torch._C._freeze_module(model._c)
# verify mTrain_freezed looks exactly as:
# module {
# attributes {
# conv1 = ...
# conv2 = ...
# dropout1 = ...
# dropout2 = ...
# fc1 = ...
# fc2 = ...
# }
# ...
# submodules {
# module conv1 {
# attributes {
# weight = ...
# bias = ...
# }
# ...
# }
# module conv2 {
# attributes {
# weight = ...
# bias = ...
# }
# ...
# }
# module dropout1 {
# attributes {
# training = ...
# }
# ...
# }
# module dropout2 {
# attributes {
# training = ...
# }
# ...
# }
# module fc1 {
# attributes {
# weight = ...
# bias = ...
# }
# ...
# }
# module fc2 {
# attributes {
# weight = ...
# bias = ...
# }
# ...
# }
self.assertFalse(mTrain_freezed.hasattr('training'))
self.assertTrue(mTrain_freezed.hasattr('conv1'))
self.assertFalse(mTrain_freezed.conv1.hasattr('training'))
self.assertTrue(mTrain_freezed.conv1.hasattr('weight'))
self.assertTrue(mTrain_freezed.conv1.hasattr('bias'))
self.assertTrue(mTrain_freezed.hasattr('conv2'))
self.assertFalse(mTrain_freezed.conv2.hasattr('training'))
self.assertTrue(mTrain_freezed.conv2.hasattr('weight'))
self.assertTrue(mTrain_freezed.conv2.hasattr('bias'))
self.assertTrue(mTrain_freezed.hasattr('dropout1'))
self.assertTrue(mTrain_freezed.dropout1.hasattr('training'))
self.assertTrue(mTrain_freezed.hasattr('dropout2'))
self.assertTrue(mTrain_freezed.dropout2.hasattr('training'))
self.assertTrue(mTrain_freezed.hasattr('fc1'))
self.assertTrue(mTrain_freezed.fc1.hasattr('weight'))
self.assertTrue(mTrain_freezed.fc1.hasattr('bias'))
self.assertTrue(mTrain_freezed.hasattr('fc2'))
self.assertTrue(mTrain_freezed.fc2.hasattr('weight'))
self.assertTrue(mTrain_freezed.fc2.hasattr('bias'))
model.eval()
mEval_freezed = torch._C._freeze_module(model._c)
self.assertFalse(mEval_freezed.hasattr('conv1'))
self.assertFalse(mEval_freezed.hasattr('conv2'))
self.assertFalse(mEval_freezed.hasattr('dropout1'))
self.assertFalse(mEval_freezed.hasattr('training'))
self.assertFalse(mEval_freezed.hasattr('fc1'))
self.assertFalse(mEval_freezed.hasattr('dropout2'))
self.assertFalse(mEval_freezed.hasattr('fc2'))
with self.assertRaisesRegex(AttributeError, "does not have a field with name 'state_dict'"):
print(mEval_freezed.state_dict())
buffer = io.BytesIO()
torch.jit.save(mEval_freezed, buffer)
buffer.seek(0)
m = torch.jit.load(buffer)
FileCheck().check_not('GetAttr[name=') \
.run(m._c._get_method('forward').graph)
m2 = torch._C._freeze_module(model._c, preserveParameters=True)
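        # preserveParameters=True keeps parameters (the conv/fc weights and biases)
        # as attributes instead of folding them, while non-parameter attributes such
        # as the dropout submodules and 'training' are still removed.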
self.assertTrue(m2.hasattr('conv1'))
self.assertTrue(m2.hasattr('conv2'))
self.assertFalse(m2.hasattr('dropout1'))
self.assertFalse(m2.hasattr('training'))
self.assertTrue(m2.hasattr('fc1'))
self.assertFalse(m2.hasattr('dropout2'))
self.assertTrue(m2.hasattr('fc2'))
def test_freeze_module_detach_gradient(self):
mod = nn.Conv2d(8, 3, 4, 2, 1)
self.assertTrue(mod.weight.requires_grad)
smod = torch.jit.script(mod)
smod.eval()
fmod = torch._C._freeze_module(smod._c)
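        # The folded weight becomes a detached constant inside the frozen graph;
        # the original, still-trainable parameters live on in mod and smod, as
        # checked below.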
self.assertTrue(mod.weight.requires_grad)
self.assertTrue(smod.weight.requires_grad)
self.assertFalse(fmod.hasattr('weight'))
inp = torch.ones(1, 8, 32, 32)
out1 = fmod.forward(inp)
# FIXME: frozen module mutated from outside (original module).
with torch.no_grad():
smod.weight[0, 0, 0, 0] += 100.0
out2 = fmod.forward(inp)
out3 = smod(inp)
self.assertNotEqual(out1, out2)
self.assertEqual(out2, out3)
def test_freeze_module_with_user_preserved_attr(self):
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
self.a = torch.tensor([1.1])
self.b = torch.tensor([2.2])
def forward(self, x):
return self.a + self.b
m = torch.jit.script(Module())
m.eval()
fm = torch._C._freeze_module(m._c, ["a"])
# Attribute "a" is preserved
self.assertTrue(fm.hasattr("a"))
self.assertFalse(fm.hasattr("b"))
def test_freeze_module_with_user_preserved_method(self):
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
self.a = torch.tensor([1.1])
self.b = torch.tensor([2.2])
def forward(self, x):
return self.a + self.b
@torch.jit.export
def modify_a(self, x):
self.a[0] += 10
return self.b
@torch.jit.export
def modify_b(self, x):
self.b[0] += 20
return self.a
m = torch.jit.script(Module())
m.eval()
fm = torch._C._freeze_module(m._c, ["modify_a"])
# Both attribute "a" and method "modify_a" are preserved
self.assertTrue(fm.hasattr("a"))
self.assertFalse(fm.hasattr("b"))
input = torch.randn(2, 2)
expected = m.forward(input)
out = fm.forward(input)
self.assertEqual(out, expected)
def test_freeze_module_with_user_preserved_method2(self):
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
self.a = torch.tensor([1.1])
self.b = torch.tensor([2.2])
def forward(self, x):
self.b += 10
return self.a + self.b
@torch.jit.export
def modify_a(self, x):
self.a[0] += 10
return self.b + self.a
m = torch.jit.script(Module())
m.eval()
fm = torch._C._freeze_module(m._c, ["modify_a"])
FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph)
FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph)
@skipIfNoFBGEMM
def test_module_with_shared_type_instances(self):
class Child(nn.Module):
def __init__(self):
super(Child, self).__init__()
self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)
def forward(self, x):
x = self.conv1(x)
return x
class Parent(nn.Module):
def __init__(self):
super(Parent, self).__init__()
self.quant = torch.quantization.QuantStub()
self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)
self.child = Child()
self.child2 = Child()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.conv1(x)
x = self.child(x)
x = self.child2(x)
x = self.dequant(x)
return x
def _static_quant(model):
qModel = torch.quantization.QuantWrapper(model)
qModel.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(qModel, inplace=True)
qModel(torch.rand(4, 1, 4, 4, dtype=torch.float32))
torch.quantization.convert(qModel, inplace=True)
return model
with override_quantized_engine('fbgemm'):
data = torch.randn(4, 1, 4, 4, dtype=torch.float32)
m = Parent().to(torch.float32)
m = _static_quant(m)
m = torch.jit.script(m)
m.eval()
torch._C._jit_pass_inline(m.graph)
m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c))
# Earlier bug resulted in _packed_params set to false.
FileCheck().check_not('_packed_params = False').run(m_frozen._c.dump_to_str(True, True, False))
m_res = m(data)
# It used to segfault while running frozen module.
m_frozen_res = m_frozen(data)
self.assertEqual(m_res, m_frozen_res)
def test_module_getattr_indirection(self):
@torch.jit.script
class ValHolder(object):
def __init__(self, val: int):
self.val: int = val
class Mod(nn.Module):
def __init__(self):
super(Mod, self).__init__()
self.mod1 = ValHolder(1)
self.mod2 = ValHolder(2)
def forward(self, cond: bool):
if cond:
mod = self.mod1
else:
mod = self.mod2
return mod.val
mod = Mod()
mod.eval()
frozen_mod = torch.jit.freeze(torch.jit.script(mod))
mod_eager = Mod()
self.assertEqual(mod_eager(True), frozen_mod(True))
self.assertEqual(mod_eager(False), frozen_mod(False))
def test_freeze_module_with_non_static_module_container_index(self):
"""
Test that Modules containing non-static ModuleDict or ModuleList
indexing cannot be frozen.
"""
@torch.jit.interface
class ModuleInterface(torch.nn.Module):
def forward(self, inp: Any) -> Any:
pass
class ImplementsInterface(torch.nn.Module):
def forward(self, inp: Any) -> Any:
if isinstance(inp, torch.Tensor):
return torch.max(inp, dim=0)
return inp
class ModWithDict(torch.nn.Module):
def __init__(self):
super().__init__()
self.d = torch.nn.ModuleDict({"module": ImplementsInterface()})
def forward(self, x: torch.Tensor, key: str) -> Any:
value: ModuleInterface = self.d[key]
return value.forward(x)
m = torch.jit.script(ModWithDict())
m.eval()
with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleContainerIndex is not supported"):
mf = torch._C._freeze_module(m._c)
class ModWithList(torch.nn.Module):
def __init__(self):
super().__init__()
self.l = torch.nn.ModuleList([ImplementsInterface()])
def forward(self, x: torch.Tensor, idx: int) -> Any:
value: ModuleInterface = self.l[idx]
return value.forward(x)
m = torch.jit.script(ModWithList())
m.eval()
with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleContainerIndex is not supported"):
mf = torch._C._freeze_module(m._c)
def test_freeze_non_module_class_getattr(self):
class BoxCoder(object):
def __init__(self, bbox_xform_clip):
# type: (float) -> None
self.bbox_xform_clip = bbox_xform_clip
def decode(self, input):
return input * self.bbox_xform_clip
class MyModule(torch.nn.Module):
__annotations__ = {
'box_coder': BoxCoder,
}
def __init__(self):
super(MyModule, self).__init__()
self.box_coder = BoxCoder(50.)
def forward(self, input):
return self.box_coder.decode(input)
model = MyModule()
model.eval()
script_model = torch.jit.freeze(torch.jit.script(model))
inp = torch.randn([4, 4])
output_eager = model(inp)
self.assertEqual(model(inp), script_model(inp))
FileCheck().check_not("GetAttr").run(script_model.graph)
def test_freeze_module_with_tupleoutput_submodule(self):
class SubModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return (x + 1, x + 2)
class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.sub = SubModule()
def forward(self, x):
y1, y2 = self.sub(x)
return y1 + y2
m = torch.jit.script(TestModule())
m = m.eval()
mf = torch.jit.freeze(m)
inp = torch.randn(2, 2)
expected = m.forward(inp)
output = mf.forward(inp)
        # Check that prim::TupleConstruct and prim::TupleUnpack
        # don't exist in the frozen graph
FileCheck().check_not("prim::TupleConstruct").run(mf.graph)
FileCheck().check_not("prim::TupleUnpack").run(mf.graph)
self.assertEqual(output, expected)
class TestFrozenOptimizations(JitTestCase):
def setUp(self):
self.default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.double)
def tearDown(self):
torch.set_default_dtype(self.default_dtype)
def test_conv_bn_folding(self):
conv_bias = [True, False]
module_pairs = [(nn.Conv1d, nn.BatchNorm1d), (nn.Conv2d, nn.BatchNorm2d), (nn.Conv3d, nn.BatchNorm3d)]
use_tracing = [True, False]
for use_bias, modules, tracing in product(conv_bias, module_pairs, use_tracing):
class ConvBN(torch.nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(ConvBN, self).__init__()
self.conv = modules[0](in_channels, out_channels, bias=use_bias, **kwargs)
self.bn = modules[1](out_channels, eps=0.001)
def forward(self, x):
x = self.conv(x)
return self.bn(x)
mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval()
inps = [4, 3, 4]
if modules[0] == nn.Conv2d:
inps.append(inps[-1])
if modules[0] == nn.Conv3d:
inps.append(inps[-1])
inps.append(inps[-1])
inp = torch.rand(inps)
if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
else:
scripted_mod = torch.jit.script(mod_eager)
self.run_pass("inline", scripted_mod.graph)
self.run_pass("peephole", scripted_mod.graph)
self.run_pass("constant_propagation", scripted_mod.graph)
FileCheck().check("conv").check("batch").run(scripted_mod.graph)
# successfully no-ops with non-const inputs
self.run_pass("fold_frozen_conv_bn", scripted_mod.graph)
FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph)
scripted_mod = torch.jit.freeze(scripted_mod)
self.run_pass("fold_frozen_conv_bn", scripted_mod.graph)
FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph)
self.assertEqual(mod_eager(inp), scripted_mod(inp))
self.assertEqual(mod_eager(inp), scripted_mod(inp))
def test_conv_add_folding(self):
@torch.no_grad()
def test_conv_fusion(use_bias, module, tracing, op, scalar, add_tensor, expect_success):
class ConvOp(torch.nn.Module):
__constants__ = ['use_scalar']
def __init__(self, in_channels, out_channels, tensor=None, **kwargs):
super(ConvOp, self).__init__()
self.conv = module(in_channels, out_channels, bias=use_bias, **kwargs)
self.conv2 = module(in_channels, out_channels, bias=use_bias, **kwargs)
self.use_scalar = scalar
tensor_size = [1 for _ in range(self.conv.weight.ndim)]
tensor_size[1] = self.conv.weight.size(0)
self.tensor = add_tensor if add_tensor is not None else torch.rand(tensor_size)
self.op = op
def forward(self, x):
x = self.conv(x)
if self.use_scalar:
return self.op(x, 2.)
else:
return self.op(x, self.tensor)
mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval()
inps = [4, 3, 4]
if module == nn.Conv2d:
inps.append(inps[-1])
if module == nn.Conv3d:
inps.append(inps[-1])
inps.append(inps[-1])
inp = torch.rand(inps)
if tracing:
scripted_mod = torch.jit.trace(mod_eager, (inp,))
else:
scripted_mod = torch.jit.script(mod_eager)
self.run_pass("inline", scripted_mod.graph)
op_str = "aten::" + op.__name__
FileCheck().check("conv").check(op_str).run(scripted_mod.graph)
            # successfully no-ops with non-const inputs
self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph)
self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph)
FileCheck().check("conv").check(op_str).run(scripted_mod.graph)
scripted_mod = torch.jit.freeze(scripted_mod)
self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph)
self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph)
if expect_success:
FileCheck().check("conv").check_not(op_str).run(scripted_mod.graph)
else:
FileCheck().check("conv").check(op_str).run(scripted_mod.graph)
self.assertEqual(mod_eager(inp), scripted_mod(inp))
self.assertEqual(mod_eager(inp), scripted_mod(inp))
conv_bias = [True, False]
modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
use_tracing = [False, True]
use_scalar = [False, True]
ops = [torch.add, torch.sub, torch.mul, torch.div]
for use_bias, module, tracing, pytorch_op, scalar in product(conv_bias, modules, use_tracing, ops, use_scalar):
test_conv_fusion(use_bias, module, tracing, pytorch_op, scalar, add_tensor=None, expect_success=True)
for use_bias, pytorch_op in product(conv_bias, ops):
# broadcasting add
test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False,
add_tensor=torch.rand(32, 1, 32), expect_success=False)
# broadcasting add
test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False, add_tensor=torch.rand(1, 1), expect_success=True)
# add with different dtype
test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False,
add_tensor=torch.rand(1).to(torch.int), expect_success=False)
def test_optimize_freeze_module(self):
in_channels, out_channels = 3, 32
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
mod = torch.nn.Sequential(conv, bn)
        # set optimize_numerics=False here; by default freezing runs run_frozen_optimizations
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
# inspect frozen mod
FileCheck().check("batch_norm").run(frozen_mod.graph)
torch.jit.run_frozen_optimizations(frozen_mod)
FileCheck().check_not("batch_norm").run(frozen_mod.graph)
# run_frozen_optimizations should be run
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
FileCheck().check_not("batch_norm").run(frozen_mod.graph)
def test_freeze_remove_dropout(self):
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.dropout = nn.Dropout(0.5)
def forward(self, x):
return self.dropout(x)
mod = torch.jit.script(Net())
# inspect mod
torch._C._jit_pass_inline(mod.graph)
FileCheck().check("aten::dropout").run(mod.graph)
frozen_mod = torch.jit.freeze(mod.eval())
FileCheck().check_not("aten::dropout").run(frozen_mod.graph)
input = torch.randn(2)
output_s = mod.forward(input)
output_f = frozen_mod.forward(input)
self.assertEqual(output_s, output_f)
def test_freeze_remove_feature_dropout(self):
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.dropout = nn.Dropout2d(0.5)
def forward(self, x):
return self.dropout(x)
mod = torch.jit.script(Net().eval())
# inspect mod
torch._C._jit_pass_inline(mod.graph)
FileCheck().check("aten::feature_dropout").run(mod.graph)
frozen_mod = torch.jit.freeze(mod)
FileCheck().check_not("aten::feature_dropout").run(frozen_mod.graph)
input = torch.randn(2, 2)
output_s = mod.forward(input)
output_f = frozen_mod.forward(input)
self.assertEqual(output_s, output_f)
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_freeze_mkdlnn(self):
conv = torch.nn.Conv2d(3, 32, kernel_size=3, stride=2).eval().float()
convmkl = mkldnn_utils.to_mkldnn(conv)
out = torch.jit.freeze(torch.jit.script(convmkl.eval()))
inp = torch.rand([4, 3, 4, 4]).float()
self.assertEqual(out(inp.to_mkldnn()).to_dense(), conv(inp))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_conv_to_mkldnn(self):
with set_default_dtype(torch.float):
for module, trace in product([nn.Conv2d, nn.Conv3d], [False, True]):
mod = module(3, 32, kernel_size=3, stride=2).eval()
inps = [4, 3, 4]
if module == nn.Conv2d:
inps.append(inps[-1])
if module == nn.Conv3d:
inps.append(inps[-1])
inps.append(inps[-1])
inp = torch.rand(inps)
                if trace:
                    scripted_mod = torch.jit.trace(mod, (inp,))
                else:
                    scripted_mod = torch.jit.script(mod)
self.run_pass("inline", scripted_mod.graph)
FileCheck().check("conv").run(scripted_mod.graph)
# successfully no-ops with non-const inputs
self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
FileCheck().check_not("to_mkldnn").run(scripted_mod.graph)
scripted_mod = torch.jit.freeze(scripted_mod)
self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check("to_dense").run(scripted_mod.graph)
self.assertEqual(mod(inp), scripted_mod(inp))
self.assertEqual(mod(inp), scripted_mod(inp))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_linear_to_mkldnn(self):
with set_default_dtype(torch.float):
# make sure mkldnn handles broadcast rules
inp_shapes = [[20], [20, 20], [1, 20, 20]]
for inp_shape in inp_shapes:
mod = nn.Linear(20, 30).eval()
scripted_mod = torch.jit.script(mod)
inp = torch.rand(inp_shape)
self.run_pass("inline", scripted_mod.graph)
FileCheck().check("aten::linear").run(scripted_mod.graph)
# successfully no-ops with non-const inputs
self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
FileCheck().check_not("ConvertToMKLDNN").run(scripted_mod.graph)
scripted_mod = torch.jit.freeze(scripted_mod)
self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
FileCheck().check("to_mkldnn").check("aten::linear").check("to_dense").run(scripted_mod.graph)
self.assertEqual(mod(inp), scripted_mod(inp))
self.assertEqual(mod(inp), scripted_mod(inp))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_collapse_adjacent_conversions(self):
with set_default_dtype(torch.float):
mod = nn.Sequential(nn.Linear(20, 20), nn.Linear(20, 20)).eval()
scripted_mod = torch.jit.script(mod)
scripted_mod = torch.jit.freeze(scripted_mod)
self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
FileCheck().check("to_mkldnn").check("aten::linear").check("aten::linear").check("to_dense").run(scripted_mod.graph)
FileCheck().check_count("to_mkldnn", 1, exactly=True).run(scripted_mod.graph)
inp = torch.rand([20, 20])
self.assertEqual(scripted_mod(inp), mod(inp))
self.assertEqual(scripted_mod(inp), mod(inp))
# testing unsupported behavior
class Add(nn.Module):
def __init__(self, tensor):
super().__init__()
self.tensor = tensor
def forward(self, x):
return x + self.tensor
def test_unsupported(module, preserved_attrs=None):
mod = torch.jit.freeze(torch.jit.script(module.eval()), preserved_attrs)
self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
FileCheck().check("to_mkldnn").check("linear").check("to_dense").check("add").run(mod.graph)
lin = nn.Linear(20, 20)
# Scalar-Tensor not supported
test_unsupported(nn.Sequential(lin, Add(.5)))
            # 0-dim not supported
test_unsupported(nn.Sequential(lin, Add(torch.tensor(.5))))
# tensor of unknown dtype (getAttr node here) not supported
test_unsupported(nn.Sequential(lin, Add(torch.tensor([20]))), ['1'])
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_mkldnn_fuser_broadcasting(self):
class Add(nn.Module):
def __init__(self, tensor):
super().__init__()
self.tensor = tensor
def forward(self, x):
return x + self.tensor
with set_default_dtype(torch.float):
for add_inp in [20], [20, 20, 1]:
mod = nn.Sequential(nn.Linear(20, 20), Add(torch.rand(add_inp))).eval()
scripted_mod = torch.jit.script(mod)
scripted_mod = torch.jit.freeze(scripted_mod)
self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
FileCheck().check("prim::BroadcastMKLDNNTensors").run(scripted_mod.graph)
inp = torch.rand([20, 20])
self.assertEqual(scripted_mod(inp), mod(inp))
self.assertEqual(scripted_mod(inp), mod(inp))
# for good measure, check that broadcasting does not work without this op
# so we can remove the op if it ever gets supported
with self.assertRaisesRegex(RuntimeError, ""):
torch.rand([20, 20]).to_mkldnn() + torch.rand(add_inp).to_mkldnn()
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_mkldnn_inplace_removal(self):
class AddMul(nn.Module):
def __init__(self, tensor):
super().__init__()
self.tensor = tensor
def forward(self, x):
return x.add_(self.tensor).div_(self.tensor) - 4
with set_default_dtype(torch.float):
mod = nn.Sequential(nn.Linear(20, 20), AddMul(torch.rand([20]))).eval()
scripted_mod = torch.jit.script(mod)
scripted_mod = torch.jit.freeze(scripted_mod)
self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
# add gets uninplaced and reinplaced
FileCheck().check("aten::to_mkldnn").check("aten::add_").check("aten::div_").run(scripted_mod.graph)
inp = torch.rand([20, 20])
self.assertEqual(scripted_mod(inp), mod(inp))
self.assertEqual(scripted_mod(inp), mod(inp))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
@skipIfNoTorchVision
def test_maxpool_mkldnn(self):
with set_default_dtype(torch.float):
model = torchvision.models.resnet18()
sub_model = torch.nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool)
mod = torch.jit.freeze(torch.jit.script(sub_model.eval()))
N, C, H, W, = 10, 3, 224, 224
inp = torch.randn(N, C, H, W)
self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
FileCheck().check("max_pool").check("to_dense").run(mod.graph)
FileCheck().check_count("to_dense", 1, exactly=True).run(mod.graph)
self.assertEqual(mod(inp), sub_model(inp))
@unittest.skipIf(torch._C.has_mkldnn, "Testing no mkldnn")
def test_conv_to_mkldnn_no_mkldnn(self):
# test no error when mkldnn not available
with set_default_dtype(torch.float):
mod = torch.jit.script(nn.Conv2d(3, 32, kernel_size=3, stride=2).eval())
frozen = torch.jit.freeze(mod)
self.run_pass("convert_frozen_ops_to_mkldnn", frozen.graph)
inp = torch.rand([4, 3, 4, 4])
self.assertEqual(frozen(inp), mod(inp))
@unittest.skipIf(not TEST_CUDNN, "requires CUDNN")
def test_freeze_conv_relu_fusion(self):
conv_bias = [True, False]
conv_ops = [nn.Conv2d, nn.Conv3d]
add_z = [True, False]
use_tracing = [True, False]
for use_bias, conv, add_z, tracing in product(conv_bias, conv_ops, add_z, use_tracing):
class Net(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(Net, self).__init__()
self.conv = conv(in_channels, out_channels, bias=use_bias, **kwargs)
self.relu = nn.ReLU(inplace=True)
self.add_z = add_z
def forward(self, x):
z = self.conv(x)
out = self.conv(x)
if self.add_z:
out += z
out = self.relu(out)
return out
mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda()
inps = [5, 3, 4, 4]
if conv == nn.Conv3d:
inps.append(inps[-1])
inp = torch.rand(inps).cuda()
if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
else:
scripted_mod = torch.jit.script(mod_eager)
frozen_mod = torch.jit.optimize_for_inference(scripted_mod)
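            # optimize_for_inference freezes the module if needed and applies
            # backend-specific fusions; on CUDA builds with cuDNN this includes
            # fusing conv + relu (and conv + add + relu) into single kernels.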
if add_z:
FileCheck().check("aten::cudnn_convolution_add_relu").run(frozen_mod.graph)
else:
FileCheck().check("aten::cudnn_convolution_relu").run(frozen_mod.graph)
self.assertEqual(mod_eager(inp), frozen_mod(inp))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_incompatible_perf_formats(self):
with set_default_dtype(torch.float):
class Mod(nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 64, 3, 2)
self.max_pool = torch.nn.MaxPool2d(111, 111)
def forward(self, x):
a = self.conv(x)
b = self.max_pool(a)
return a + b
model = Mod()
model.eval()
mod = torch.jit.freeze(torch.jit.script(model))
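            # conv and max_pool outputs may end up in different MKLDNN ("perf")
            # memory formats; the conversion pass has to account for this so that
            # a + b still computes correctly.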
N, C, H, W, = 10, 3, 224, 224
inp = torch.randn(N, C, H, W)
self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
self.assertEqual(model(inp), mod(inp))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_pool2d_batchnorm(self):
with set_default_dtype(torch.float):
pooling_layers = [torch.nn.AdaptiveAvgPool2d(4),
# torch.nn.AdaptiveMaxPool2d(4), # return tuples
torch.nn.MaxPool2d(4),
torch.nn.AvgPool2d(4),
torch.nn.BatchNorm2d(64).eval()]
for pl in pooling_layers:
sub_model = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 2, 2), torch.nn.ReLU(), pl, torch.nn.Hardswish())
sub_model.eval()
mod = torch.jit.freeze(torch.jit.script(sub_model))
N, C, H, W, = 10, 3, 224, 224
inp = torch.randn(N, C, H, W)
                # these two passes are needed to remove
                # a size check in BatchNorm2d
removeExceptions(mod.graph)
self.run_pass('dce', mod.graph)
self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
self.assertEqual(sub_model(inp), mod(inp))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_pool3d_batchnorm(self):
with set_default_dtype(torch.float):
pooling_layers = [torch.nn.MaxPool3d(4),
# torch.nn.AdaptiveAvgPool3d(4), # no ideep bindings
# torch.nn.AdaptiveMaxPool3d(4), # return tuples
torch.nn.AvgPool3d(4),
torch.nn.BatchNorm3d(64).eval()]
for pl in pooling_layers:
sub_model = torch.nn.Sequential(torch.nn.Conv3d(3, 64, 2, 2), torch.nn.ReLU(), pl, torch.nn.Hardswish())
sub_model.eval()
mod = torch.jit.freeze(torch.jit.script(sub_model))
N, C, H, W, D = 10, 3, 64, 64, 64
inp = torch.randn(N, C, D, H, W)
                # these two passes are needed to remove
                # a size check in BatchNorm3d
removeExceptions(mod.graph)
self.run_pass('dce', mod.graph)
self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
self.assertEqual(sub_model(inp), mod(inp))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
@skipIfNoTorchVision
def test_layernorm(self):
with set_default_dtype(torch.float):
class ResidualLayernorm(torch.nn.Module):
def __init__(self, op, layernorm, **kwargs):
super(ResidualLayernorm, self).__init__()
self.op = op
self.layernorm = layernorm
def forward(self, x):
y = self.op(x)
return self.layernorm(y) + y
model = torchvision.models.resnet18()
N, C, H, W, = 10, 3, 224, 224
for param in ((model.conv1, [W // 2], torch.randn(N, C, H, W)),
(model.conv1, [H // 2, W // 2], torch.randn(N, C, H, W)),
(torch.nn.Linear(H, W), [W], torch.randn(N, C, W)),):
for layernorm in (torch.nn.LayerNorm(param[1]),
torch.nn.LayerNorm(param[1], elementwise_affine=False)):
                    # to generate non-inplace tests, we extend the use of layernorm's input
for inplace in (True, False):
sub_model = torch.nn.Sequential(param[0], layernorm) if inplace else ResidualLayernorm(param[0], layernorm)
sub_model.eval()
mod = torch.jit.freeze(torch.jit.script(sub_model))
self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
# if weight and bias are present and shape is the last dimension
# we should convert `aten::layer_norm` to `prim::MKLDNNLayerNorm`
if layernorm.elementwise_affine and len(param[1]) == 1:
inplace_suffix = "_" if inplace else ""
(FileCheck().check("prim::MKLDNNLayerNorm" + inplace_suffix).
check_count("aten::to_dense", 1, exactly=True).run(mod.graph))
else:
FileCheck().check_count("aten::to_dense", 1, exactly=True).check("aten::layer_norm").run(mod.graph)
self.assertEqual(sub_model(param[2]), mod(param[2]), rtol=1e-04, atol=1e-04)
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
@skipIfNoTorchVision
def test_conv_hardswish(self):
with set_default_dtype(torch.float):
class Clamp(torch.nn.Module):
def __init__(self, min_val, max_val, **kwargs):
super(Clamp, self).__init__()
self.min_val = min_val
self.max_val = max_val
def forward(self, x):
return torch.clamp(x, self.min_val, self.max_val)
N, C, H, W, = 10, 3, 224, 224
activations = [
torch.nn.Hardswish(),
torch.nn.Hardsigmoid(),
torch.nn.ReLU6(),
torch.nn.Tanh(),
torch.nn.Hardtanh(0., 6.),
torch.nn.Hardtanh(1., 100.),
torch.nn.Hardtanh(-100., -1.),
torch.nn.GELU(),
Clamp(-100., -1.),
Clamp(1., 100.),
Clamp(0., 6.),
Clamp(-1., 0.),
]
model = torchvision.models.resnet18()
for activation in activations:
sub_model = torch.nn.Sequential(model.conv1, activation)
sub_model.eval()
mod = torch.jit.freeze(torch.jit.script(sub_model))
inp = torch.randn(N, C, H, W)
self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
FileCheck().check_count("aten::to_dense", 1, exactly=True).run(mod.graph)
self.assertEqual(sub_model(inp), mod(inp))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_hardswish_hardsigmoid(self):
with set_default_dtype(torch.float):
op_map = {
'prim::MKLDNNHardSwish' : F.hardswish,
'prim::MKLDNNHardSigmoid' : F.hardsigmoid,
}
input_sizes = ([0], [1], [3], [1, 3, 8, 8])
for (mkldnn_opname, aten_op) in op_map.items():
for size in input_sizes:
for inplace in (True, False):
inplace_str = "_" if inplace else ""
inplace_tgt = "%34" if inplace else "%35"
graph_str = f"""graph(%input.1 : Tensor):
%33 : None = prim::Constant()
%34 : Tensor = aten::to_mkldnn(%input.1, %33)
%35 : Tensor = {mkldnn_opname}{inplace_str}(%34)
return ({inplace_tgt})
"""
g = parse_ir(graph_str)
m = self.createFunctionFromGraph(g)
x = torch.rand(size)
                        # `inplace=False` is intentional; otherwise we would modify the input,
                        # and we aren't testing the aten impls here anyway
self.assertEqual(aten_op(x, inplace=False), m(x).to_dense())
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_scalar_mul(self):
with set_default_dtype(torch.float):
class Mod(nn.Module):
def __init__(self):
super().__init__()
self.mod = nn.Linear(20, 20)
def forward(self, x):
a1 = self.mod(x) * 4
return a1 * 4 + a1 * 5.
mod = Mod().eval()
scripted = torch.jit.freeze(torch.jit.script(mod))
optimized = torch.jit.optimize_for_inference(scripted)
inp = torch.rand([20, 20])
            # a1 can't be inplaced for its first use, but can be for the second
FileCheck().check("ScalarMul_").check("ScalarMul(").check("ScalarMul_").run(optimized.graph)
self.assertEqual(optimized(inp), mod(inp))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
def test_optimize_for_inference(self):
with set_default_dtype(torch.float):
mod = nn.Linear(20, 30).eval()
scripted_mod = torch.jit.script(mod)
optimized = torch.jit.optimize_for_inference(scripted_mod)
FileCheck().check("to_mkldnn").run(optimized.graph)
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
optimized = torch.jit.optimize_for_inference(scripted_mod)
FileCheck().check("to_mkldnn").run(optimized.graph)
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
class TestMKLDNNReinplacing(JitTestCase):
def setUp(self):
self.default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float)
def tearDown(self):
torch.set_default_dtype(self.default_dtype)
def getConv(self):
return nn.Conv2d(3, 32, kernel_size=3, stride=2).eval()
def getInput(self):
return torch.rand([4, 3, 4, 4])
def freezeAndConvert(self, mod):
mod = torch.jit.freeze(torch.jit.script(mod.eval()))
self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
return mod
def checkResults(self, mod1, mod2):
inp = self.getInput()
self.assertEqual(mod1(inp), mod2(inp))
def test_successful(self):
# simple conv-relu
mod_eager = nn.Sequential(self.getConv(), nn.Hardswish(), nn.ReLU())
mod = self.freezeAndConvert(mod_eager)
FileCheck().check("mkldnn_convolution").check_next("prim::MKLDNNHardSwish_").check_next("aten::relu_").run(mod.graph)
self.checkResults(mod_eager, mod)
def test_merge_liveness(self):
class Mod(nn.Module):
def __init__(self, tensor):
super().__init__()
self.tensor = tensor
def forward(self, x):
# this mul can be inplaced since x is dead after this use
temporary = x * self.tensor
                # temporary's lifespan extends to the return node,
                # so the add can not be inplaced
return temporary + temporary, temporary
mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
mod = self.freezeAndConvert(mod_eager)
FileCheck().check("aten::mul_").check_not("aten::add_").run(mod.graph)
self.checkResults(mod_eager, mod)
def test_always_alive_values(self):
class Mod(nn.Module):
def __init__(self, tensor):
super().__init__()
self.tensor = tensor
def forward(self, x):
                # x can't be inplaced because it is a return value;
                # check that the inplacing pass doesn't try to inplace
                # self.tensor because it is always alive
return x * self.tensor, x
mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
mod = self.freezeAndConvert(mod_eager)
FileCheck().check_not("aten::mul_").run(mod.graph)
self.checkResults(mod_eager, mod)
conv = self.getConv()
class Mod(nn.Module):
def __init__(self):
super().__init__()
self.tensor = torch.rand([4, 32, 1, 1])
self.conv = conv
def forward(self, x):
                # the shapes don't add up here; we are just testing a particular pattern
conv_output = self.conv(x)
return conv_output, self.conv(torch.add(x, x))
mod = self.freezeAndConvert(Mod())
# x is an input to the graph, and so it should not be inplaced
# in the torch.add(x, x) call
FileCheck().check_not("aten::add_").run(mod.graph)
def test_switch_inputs_to_inplace(self):
class Mod(nn.Module):
def __init__(self, tensor):
super().__init__()
self.tensor = tensor
def forward(self, x):
                # self.tensor cannot be inplaced, but x can, and
                # because add is commutative we can reverse the inputs to add_
return self.tensor + x
mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
mod = self.freezeAndConvert(mod_eager)
FileCheck().check("aten::add_").run(mod.graph)
self.checkResults(mod_eager, mod)