Fix scriptability for Observer (#25197)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25197
Ensure that observer code remains scriptable after addition of warnings
ghstack-source-id: 89022474
Test Plan: buck test caffe2/test:quantization -- 'test_observer_scriptable \(test_quantization\.ObserverTest\)' --print-passing-details
Differential Revision: D17059486
fbshipit-source-id: 70ea9ee39f0b896c7801e168666f88c156dbf15b
diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py
index a788851..ab87d5f 100644
--- a/torch/quantization/observer.py
+++ b/torch/quantization/observer.py
@@ -53,19 +53,22 @@
qmin, qmax = -128, 127
else:
qmin, qmax = 0, 255
- scale = 1.0
- zero_point = 0
# We pull these out so that TorchScript optional type refinement works.
# We may be able to remove this in the future if TorchScript supports that
# feature on attributes
min_val = self.min_val
max_val = self.max_val
if max_val is None or min_val is None:
- warnings.warn("must run observer before calling calculate_qparams")
+ warnings.warn("must run observer before calling calculate_qparams.\
+ Returning default scale and zero point ")
+ return torch.tensor([1.0]), torch.tensor([0])
+ max_val, min_val = float(max_val), float(min_val)
+ min_val = min(0.0, min_val)
+ max_val = max(0.0, max_val)
+ if max_val == min_val:
+ scale = 1.0
+ zero_point = 0
else:
- max_val, min_val = self.max_val.item(), self.min_val.item()
- min_val = min(0.0, self.min_val.item())
- max_val = max(0.0, self.max_val.item())
if self.qscheme == torch.per_tensor_symmetric:
max_val = max(-min_val, max_val)
scale = max_val / ((qmax - qmin) / 2)