Revert D17458232: Fake quantization enhancements for QAT/PTQ support
Test Plan: revert-hammer
Differential Revision:
D17458232
Original commit changeset: f44380c60f1a
fbshipit-source-id: 64a244c720b61fa912bacbb23fcbf9faed0757c2
diff --git a/torch/quantization/fake_quantize.py b/torch/quantization/fake_quantize.py
index 6176f8c..88f24bf 100644
--- a/torch/quantization/fake_quantize.py
+++ b/torch/quantization/fake_quantize.py
@@ -1,7 +1,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from torch.nn import Module
-from .observer import MinMaxObserver, _with_args
+from .observer import default_observer, _with_args
class FakeQuantize(Module):
''' Simulate the quantize and dequantize operations in training time.
@@ -14,7 +14,7 @@
'''
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
- quant_min=0, quant_max=255, reduce_range=False):
+ quant_min=0, quant_max=255):
super(FakeQuantize, self).__init__()
assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
assert quant_min <= quant_max, \
@@ -24,45 +24,36 @@
self.qscheme = qscheme
self.quant_min = quant_min
self.quant_max = quant_max
- self.fake_quant_enabled = True
- self.observer_enabled = True
- self.observer = MinMaxObserver.with_args(dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)()
+ self.enabled = True
+ self.observer = default_observer(dtype=dtype, qscheme=qscheme)
self.scale = None
self.zero_point = None
- def enable_fake_quant(self, enabled=True):
- self.fake_quant_enabled = enabled
+ def enable(self, enabled=True):
+ self.enabled = enabled
return self
- def disable_fake_quant(self):
- return self.enable_fake_quant(False)
-
- def enable_observer(self, enabled=True):
- self.observer_enabled = enabled
-
- def disable_observer(self):
- return self.enable_observer(False)
+ def disable(self):
+ return self.enable(False)
def calculate_qparams(self):
return self.observer.calculate_qparams()
def forward(self, X):
- if self.observer_enabled:
- X = self.observer(X)
+ if self.enabled:
+ self.observer(X)
scale, zero_point = self.calculate_qparams()
self.scale, self.zero_point = float(scale), int(zero_point)
- if self.fake_quant_enabled:
- X = torch.fake_quantize_per_tensor_affine(X, self.scale, self.zero_point, self.quant_min, self.quant_max)
+ X = torch.fake_quantize_per_tensor_affine(
+ X, self.scale, self.zero_point, self.quant_min,
+ self.quant_max)
return X
with_args = classmethod(_with_args)
- def extra_repr(self):
- return 'fake_quant_enabled={}, observer_enabled={},\
- scale={}, zero_point={}'.format(
- self.fake_quant_enabled, self.observer_enabled,
- self.scale, self.zero_point)
-
default_fake_quant = FakeQuantize
-default_weight_fake_quant = FakeQuantize.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
- quant_min=-128, quant_max=127)
+
+default_weight_fake_quant = FakeQuantize.with_args(dtype=torch.qint8,
+ qscheme=torch.per_tensor_symmetric,
+ quant_min=-128,
+ quant_max=127)