# blob: d2600fb1559a0249d97695a70b0f6c4bf7993f6d [file] [log] [blame]
import torch
import torch.jit
from torch.quantization._quantize_script import script_qconfig
from torch.quantization._quantize_script import prepare_dynamic_script
from torch.quantization import default_qconfig
from torch.testing._internal.common_utils import run_tests
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import attrs_with_prefix
from torch.testing._internal.jit_utils import JitTestCase
class TestScript(JitTestCase):
    """TorchScript tests for ``prepare_dynamic_script``.

    Each test scripts a small module and runs ``prepare_dynamic_script`` on it
    with a qconfig dict keyed by qualified module name. It then checks two
    things:

    * which (sub)modules received ``_observer_*`` attributes;
    * the order in which the observers are fetched in the scripted graph.
    """

    def test_prepare_dynamic(self):
        """Qconfig on the root module: expect input and weight observers for fc."""
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        m = torch.jit.script(M())
        # The empty-string key targets the root module.
        m = prepare_dynamic_script(m, {'': script_qconfig(default_qconfig)})

        # unittest assertions (not bare ``assert``) so the checks still run
        # under ``python -O`` and report the actual count on failure.
        # One observer on the root module, for the input of fc (dynamic quant).
        self.assertEqual(len(attrs_with_prefix(m, '_observer_')), 1)
        # One observer on fc, for its weight.
        self.assertEqual(len(attrs_with_prefix(m.fc, '_observer_')), 1)

        # Expected order in the graph:
        # 1. the input observer is fetched before fc is fetched and called;
        # 2. no further observer is fetched after that call.
        (FileCheck().check('Observer = prim::GetAttr[name="_observer_')
                    .check('prim::GetAttr[name="fc"]')
                    .check('prim::CallMethod')
                    .check_not('Observer = prim::GetAttr[name="_observer_')
                    .run(m.graph))

    def test_prepare_dynamic_child_qconfig(self):
        """Qconfig only on ``sub.fc``: conv stays unobserved.

        The input observer for sub.fc is placed in the outermost module.
        """
        class Sub(torch.nn.Module):
            def __init__(self):
                super(Sub, self).__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3)
                self.sub = Sub()

            def forward(self, x):
                return self.sub(self.conv(x))

        m = torch.jit.script(M())
        # Only quantize the child module.
        m = prepare_dynamic_script(m, {'sub.fc': script_qconfig(default_qconfig)})

        # Observer for the input of sub, for dynamic quant.
        self.assertEqual(len(attrs_with_prefix(m, '_observer_')), 1)
        # conv has no qconfig, so it is not observed.
        self.assertEqual(len(attrs_with_prefix(m.conv, '_observer_')), 0)
        # No observers on sub, since we observe at the outermost call site.
        self.assertEqual(len(attrs_with_prefix(m.sub, '_observer_')), 0)
        # Observer for the weight of the Linear.
        self.assertEqual(len(attrs_with_prefix(m.sub.fc, '_observer_')), 1)

        # Match the exact attribute name "sub" (closing quote and bracket
        # included) so the pattern cannot also hit some other "sub*" attribute.
        (FileCheck().check('prim::GetAttr[name="sub"]')
                    .check('prim::CallMethod')
                    .check('Observer = prim::GetAttr[name="_observer_')
                    .check('prim::CallMethod')
                    .check_not('Observer = prim::GetAttr[name="_observer_')
                    .run(m.graph))
if __name__ == "__main__":
    # Executed as a script: hand off to the imported ``run_tests`` entry point.
    run_tests()