Use `default_observer` and `default_weight_observer` in tests (#31424)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31424
att
Test Plan:
test_jit.py
Imported from OSS
Differential Revision: D19162368
fbshipit-source-id: 33b95ba643eeeae942283bbc33f7ceda8d14c431
diff --git a/test/test_jit.py b/test/test_jit.py
index 9af15b0..a679d31 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -305,26 +305,6 @@
super(FooToPickle, self).__init__()
self.bar = torch.jit.ScriptModule()
-class Observer(torch.nn.Module):
- def __init__(self, dtype=torch.quint8):
- super(Observer, self).__init__()
- self.dtype = dtype
-
- def forward(self, x):
- return x
-
- @torch.jit.export
- def calculate_qparams(self):
- return torch.tensor([2.0]), torch.tensor([3])
-
- @torch.jit.export
- def get_qparams(self):
- return self.calculate_qparams()
-
-class WeightObserver(Observer):
- def __init__(self):
- super(WeightObserver, self).__init__(torch.qint8)
-
class TestJit(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_large_nbr_kernel_args(self):
@@ -992,7 +972,7 @@
return self.conv(x)
m = torch.jit.script(M())
- observer = torch.jit.script(Observer())
+ observer = torch.jit.script(default_observer())
qconfig_dict = {
'':
QConfig(
@@ -1054,7 +1034,7 @@
.run(str(s))
m = torch.jit.script(M())
- observer = torch.jit.script(Observer())
+ observer = torch.jit.script(default_observer())
torch._C._jit_pass_constant_propagation(get_forward_graph(m._c))
qconfig = QConfig(
@@ -1105,7 +1085,7 @@
# When we change the implementation to clone the module before
# inserting observers, we can remove this copy
m = m.copy()
- observer = torch.jit.script(Observer())
+ observer = torch.jit.script(default_observer())
qconfig_dict = {
'':
QConfig(
@@ -1140,8 +1120,8 @@
return F.relu(self.conv(x))
m = torch.jit.script(M())
- observer = torch.jit.script(Observer())
- weight_observer = torch.jit.script(WeightObserver())
+ observer = torch.jit.script(default_observer())
+ weight_observer = torch.jit.script(default_weight_observer())
qconfig_dict = {
'':
QConfig(
@@ -1229,7 +1209,7 @@
return b + c
m = torch.jit.script(M())
- observer = torch.jit.script(Observer())
+ observer = torch.jit.script(default_observer())
qconfig_dict = {
'':
QConfig(