Add FP16 weight support for LSTM in dynamic_quantize (#25975)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25975
This PR adds FP16 weight support to the dynamic quantized LSTM. `quantize_dynamic` (and the underlying `convert`/`swap_module` helpers) gain a `dtype` argument that accepts `torch.qint8` (the existing behavior) or `torch.float16`; on the FP16 path the weights are packed with `torch.fbgemm_pack_gemm_matrix_fp16` instead of being quantized to int8, and `quantized_lstm` is invoked with the module's dtype.
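
For context, a minimal usage sketch of the new `dtype` argument. The model class, layer sizes, and the explicit qconfig/mapping dicts below are illustrative assumptions, and the imports assume `quantize_dynamic` and `default_dynamic_qconfig` are exposed from `torch.quantization` as in the test file; the test below exercises the same API.
```
import torch
from torch.quantization import quantize_dynamic, default_dynamic_qconfig

# Hypothetical float model with a single nn.LSTM submodule.
class LSTMModel(torch.nn.Module):
    def __init__(self):
        super(LSTMModel, self).__init__()
        self.lstm = torch.nn.LSTM(2, 2)

    def forward(self, x, hiddens):
        return self.lstm(x, hiddens)

model = LSTMModel()
qconfig_dict = {torch.nn.LSTM: default_dynamic_qconfig}
mapping = {torch.nn.LSTM: torch.nn.quantized.dynamic.LSTM}

# Existing int8 path: weights are quantized to qint8 and prepacked.
model_int8 = quantize_dynamic(model=model, qconfig_dict=qconfig_dict,
                              mapping=mapping, dtype=torch.qint8)

# New fp16 path: weights are packed via torch.fbgemm_pack_gemm_matrix_fp16
# and the quantized LSTM runs with dtype=torch.float16 at inference time.
model_fp16 = quantize_dynamic(model=model, qconfig_dict=qconfig_dict,
                              mapping=mapping, dtype=torch.float16)
```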
Test Plan:
buck test mode/dev caffe2/test:quantization -- 'test_quantized_rnn \(test_quantization\.PostTrainingDynamicQuantTest\)' --print-passing-details
```
[jianyuhuang@devvm794.ftw3.facebook.com: ~/fbsource/fbcode/caffe2/test] $ buck test mode/dev caffe2/test:quantization
-- 'test_quantized_rnn \(test_quantization\.PostTrainingDynamicQuantTest\)' --print-passing-details
Building: finished in 13.4 sec (100%) 8134/8134 jobs, 81 updated
Total time: 13.9 sec
Trace available for this run at /tmp/testpilot.20190910-210241.2092790.log
TestPilot test runner for Facebook. See https://fburl.com/testpilot for details.
Testpilot build revision c86e65add357582accb6ec0be23b92c8a2c510bd fbpkg ca46e8f5b26c451a8b0b2462c11bb61d at Mon Sep 9
22:16:37 2019 by twsvcscm from /usr/local/fbprojects/packages/testinfra.testpilot/696/t.par
Discovering tests
Running 1 tests
Started new test run: https://our.intern.facebook.com/intern/testinfra/testrun/1125900050322971
✓ caffe2/test:quantization - test_quantized_rnn (test_quantization.PostTrainingDynamicQuantTest) 0.183 1/1 (passed)
Test output:
> test_quantized_rnn (test_quantization.PostTrainingDynamicQuantTest) ... ok
>
> ----------------------------------------------------------------------
> Ran 1 test in 0.184s
>
> OK
Finished test run: https://our.intern.facebook.com/intern/testinfra/testrun/1125900050322971
Summary (total time 4.35s):
PASS: 1
FAIL: 0
SKIP: 0
FATAL: 0
TIMEOUT: 0
OMIT: 0
```
Differential Revision: D17299116
fbshipit-source-id: 7fe91ece25867f2c0496f1b63fb1041e6b815166
diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp
index b0ae2df..1b01f64 100644
--- a/aten/src/ATen/native/RNN.cpp
+++ b/aten/src/ATen/native/RNN.cpp
@@ -1072,11 +1072,13 @@
check_device(_input, _params, hx);
auto input = batch_first ? _input.transpose(0, 1) : _input;
TORCH_CHECK(has_biases, "quantized LSTM requires biases");
- TORCH_CHECK(result_dtype == at::kChar || result_dtype == at::kHalf,
- "dtype is not supported");
+ TORCH_CHECK(
+ result_dtype == at::kChar || result_dtype == at::kQInt8 ||
+ result_dtype == at::kHalf,
+ "dtype is not supported");
std::tuple<Tensor, Tensor, Tensor> results;
- if (result_dtype == at::kChar) {
+ if (result_dtype == at::kChar || result_dtype == at::kQInt8) {
if (use_dynamic) {
auto params = gather_quantized_params_dynamic(_params);
results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 9f5bdcf..bad08b0 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -492,12 +492,20 @@
torch.nn.LSTM: torch.nn.quantized.dynamic.LSTM,
}
model_int8 = quantize_dynamic(
- model, qconfig_dynamic_dict, default_dynamic_module_mapping
+ model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping,
+ dtype=torch.qint8
+ )
+ model_fp16 = quantize_dynamic(
+ model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping,
+ dtype=torch.float16
)
cell_int8 = model_int8.lstm
+ cell_fp16 = model_fp16.lstm
assert type(cell_int8) == torch.nn.quantized.dynamic.LSTM, \
'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic'
+ assert type(cell_fp16) == torch.nn.quantized.dynamic.LSTM, \
+ 'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic'
niter = 10
x = torch.tensor([[100, -155],
@@ -549,6 +557,13 @@
for loaded_val, ref_val in zip(out_loaded, ref_out):
torch.testing.assert_allclose(loaded_val, ref_val)
+ # Compare fp16 quantized to unquantized
+ output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)
+
+ torch.testing.assert_allclose(output_fp16, ref_out)
+ self.assertEqual(output_fp16, ref_out)
+ for out, ref in zip(final_hiddens_fp16, ref_hid):
+ torch.testing.assert_allclose(out, ref)
@unittest.skipIf(
not torch.fbgemm_is_cpu_supported(),
diff --git a/torch/nn/quantized/dynamic/modules/rnn.py b/torch/nn/quantized/dynamic/modules/rnn.py
index 66e6504..8af7026 100644
--- a/torch/nn/quantized/dynamic/modules/rnn.py
+++ b/torch/nn/quantized/dynamic/modules/rnn.py
@@ -20,7 +20,7 @@
def __init__(self, mode, input_size, hidden_size,
num_layers=1, bias=True, batch_first=False,
- dropout=0., bidirectional=False):
+ dropout=0., bidirectional=False, dtype=torch.qint8):
super(RNNBase, self).__init__()
self.mode = mode
@@ -31,6 +31,7 @@
self.batch_first = batch_first
self.dropout = float(dropout)
self.bidirectional = bidirectional
+ self.dtype = dtype
num_directions = 2 if bidirectional else 1
if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
@@ -51,48 +52,67 @@
self._all_weight_names = []
self._all_weight_values = []
-
for layer in range(num_layers):
for direction in range(num_directions):
layer_input_size = input_size if layer == 0 else hidden_size * num_directions
- def process_weights(ihhh, layer, suffix, qweight, bias):
- # for each layer, for each direction we need to quantize and pack
- # weights and pack parameters in this order:
- #
- # w_ih, w_hh
- packed_weight = \
- torch.ops.quantized.linear_prepack(qweight, bias)
- params = [packed_weight]
- pos_names = ['w']
- ret_name = ['{}_{}_l{}{}'.format(
- name, ihhh, layer, suffix) for name in pos_names]
- return params, ret_name
+ def process_weights(ihhh, layer, suffix, qweight, bias, dtype):
+ if dtype == torch.qint8:
+ # for each layer, for each direction we need to quantize and pack
+ # weights and pack parameters in this order:
+ #
+ # w_ih, w_hh
+ packed_weight = \
+ torch.ops.quantized.linear_prepack(qweight, bias)
- w_ih = torch._empty_affine_quantized(
- [gate_size, layer_input_size], scale=1, zero_point=0, dtype=torch.qint8)
- w_hh = torch._empty_affine_quantized(
- [gate_size, hidden_size], scale=1, zero_point=0, dtype=torch.qint8)
- b_ih = torch._empty_affine_quantized(
- [gate_size], scale=1, zero_point=0, dtype=torch.qint32)
- # Second bias vector included for CuDNN compatibility. Only one
- # bias vector is needed in standard definition.
- b_hh = torch._empty_affine_quantized(
- [gate_size], scale=1, zero_point=0, dtype=torch.qint32)
+ params = [packed_weight]
+ pos_names = ['w']
+ ret_name = ['{}_{}_l{}{}'.format(
+ name, ihhh, layer, suffix) for name in pos_names]
+ return params, ret_name
+ else:
+ # for each layer, for each direction we need to quantize and pack
+ # weights and pack parameters in this order:
+ #
+ # packed_ih, packed_hh, b_ih, b_hh
+ packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
+ qweight)
+
+ params = [packed_weight, bias]
+ pos_names = ['packed', 'b']
+ ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
+ return params, ret_name
+
+ if dtype == torch.qint8:
+ w_ih = torch._empty_affine_quantized(
+ [gate_size, layer_input_size], scale=1, zero_point=0, dtype=torch.qint8)
+ w_hh = torch._empty_affine_quantized(
+ [gate_size, hidden_size], scale=1, zero_point=0, dtype=torch.qint8)
+ b_ih = torch._empty_affine_quantized(
+ [gate_size], scale=1, zero_point=0, dtype=torch.qint32)
+ # Second bias vector included for CuDNN compatibility. Only one
+ # bias vector is needed in standard definition.
+ b_hh = torch._empty_affine_quantized(
+ [gate_size], scale=1, zero_point=0, dtype=torch.qint32)
+
+ else:
+ w_ih = torch.Tensor(gate_size, layer_input_size).float()
+ w_hh = torch.Tensor(gate_size, hidden_size).float()
+ b_ih = torch.Tensor(gate_size).float()
+ # Second bias vector included for CuDNN compatibility. Only one
+ # bias vector is needed in standard definition.
+ b_hh = torch.Tensor(gate_size).float()
suffix = '_reverse' if direction == 1 else ''
ih_params, ih_param_names = process_weights(
- 'ih', layer, suffix, w_ih, b_ih)
+ 'ih', layer, suffix, w_ih, b_ih, dtype)
hh_params, hh_param_names = process_weights(
- 'hh', layer, suffix, w_hh, b_hh)
+ 'hh', layer, suffix, w_hh, b_hh, dtype)
for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)):
-
self._all_weight_names.extend([ih_name, hh_name])
self._all_weight_values.extend([ih, hh])
-
-
def check_input(self, input, batch_sizes):
# type: (Tensor, Optional[Tensor]) -> None
expected_input_dim = 2 if batch_sizes is not None else 3
@@ -150,6 +170,7 @@
self._all_weight_names,
self.__overloads__,
self.training,
+ self.dtype,
)
dynamic_vals = torch.jit.annotate(List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
@@ -173,29 +194,37 @@
self._all_weight_names = vals[8]
self.__overloads__ = vals[9]
self.training = vals[10]
+ self.dtype = vals[11]
self._all_weight_values = []
for i in range(len(self._all_weight_names)):
self._all_weight_values.append(torch.ops.quantized.linear_prepack(*dynamic_vals[i]))
@classmethod
- def from_float(cls, mod):
+ def from_float(cls, mod, dtype=torch.qint8):
assert type(mod) == torch.nn.LSTM, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM'
assert hasattr(
mod, 'qconfig'), 'Input float module must have qconfig defined'
- if mod.qconfig is not None and mod.qconfig.weight() is not None:
- weight_observer = mod.qconfig.weight()
- else:
- # We have the circular import issues if we import the qconfig in the beginning of this file:
- # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
- # import until we need it.
- from torch.quantization.QConfig import default_dynamic_qconfig
- weight_observer = default_dynamic_qconfig.weight()
- assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
+
+ supported_scalar_types = [torch.qint8, torch.float16]
+ if dtype not in supported_scalar_types:
+ raise RuntimeError('Unsupported dtype: {}'.format(dtype))
+
+ # When dtype = torch.float16, we don't need weight_observer
+ if dtype == torch.qint8:
+ if mod.qconfig is not None and mod.qconfig.weight() is not None:
+ weight_observer = mod.qconfig.weight()
+ else:
+ # We have the circular import issues if we import the qconfig in the beginning of this file:
+ # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+ # import until we need it.
+ from torch.quantization.QConfig import default_dynamic_qconfig
+ weight_observer = default_dynamic_qconfig.weight()
+ assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
if mod.mode == 'LSTM':
qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
- mod.bias, mod.batch_first, mod.dropout, mod.bidirectional)
+ mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
num_directions = 2 if mod.bidirectional else 1
@@ -211,38 +240,51 @@
for direction in range(num_directions):
layer_input_size = qRNNBase.input_size if layer == 0 else qRNNBase.hidden_size * num_directions
- def process_weights(ihhh, layer, suffix):
+ def process_weights(ihhh, layer, suffix, dtype):
weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)
weight = getattr(mod, weight_name)
bias = getattr(mod, bias_name)
- # for each layer, for each direction we need to quantize and pack
- # weights and pack parameters in this order:
- #
- # w_ih, w_hh
- weight_observer(weight)
- wt_scale, wt_zp = weight_observer.calculate_qparams()
- qweight = torch.quantize_linear(
- weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
- packed_weight = \
- torch.ops.quantized.linear_prepack(qweight, bias)
- params = [packed_weight]
- pos_names = ['w']
- ret_name = ['{}_{}_l{}{}'.format(
- name, ihhh, layer, suffix) for name in pos_names]
- return params, ret_name
+ if dtype == torch.qint8:
+ # for each layer, for each direction we need to quantize and pack
+ # weights and pack parameters in this order:
+ #
+ # w_ih, w_hh
+ weight_observer(weight)
+ wt_scale, wt_zp = weight_observer.calculate_qparams()
+ qweight = torch.quantize_linear(
+ weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
+ packed_weight = \
+ torch.ops.quantized.linear_prepack(qweight, bias)
+
+ params = [packed_weight]
+ pos_names = ['w']
+ ret_name = ['{}_{}_l{}{}'.format(
+ name, ihhh, layer, suffix) for name in pos_names]
+ return params, ret_name
+ else:
+ # for each layer, for each direction we need to quantize and pack
+ # weights and pack parameters in this order:
+ #
+ # packed_ih, packed_hh, b_ih, b_hh
+ packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
+ weight.float())
+
+ params = [packed_weight, bias]
+ pos_names = ['packed', 'b']
+ ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
+ return params, ret_name
suffix = '_reverse' if direction == 1 else ''
- ih_params, ih_param_names = process_weights('ih', layer, suffix)
- hh_params, hh_param_names = process_weights('hh', layer, suffix)
+ ih_params, ih_param_names = process_weights('ih', layer, suffix, dtype)
+ hh_params, hh_param_names = process_weights('hh', layer, suffix, dtype)
for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)):
qRNNBase._all_weight_names.extend([ih_name, hh_name])
qRNNBase._all_weight_values.extend([ih, hh])
-
return qRNNBase
@@ -273,7 +315,7 @@
result = _VF.quantized_lstm(input, hx, self._all_weight_values, self.bias, self.num_layers,
float(self.dropout), self.training, self.bidirectional,
- self.batch_first, dtype=torch.int8, use_dynamic=True)
+ self.batch_first, dtype=self.dtype, use_dynamic=True)
output = result[0]
hidden = result[1:]
@@ -330,5 +372,5 @@
return self.forward_tensor(input, hx)
@classmethod
- def from_float(cls, mod):
- return super(LSTM, cls).from_float(mod)
+ def from_float(cls, mod, dtype=torch.qint8):
+ return super(LSTM, cls).from_float(mod, dtype)
diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py
index 8b6a8d6..4c60294 100644
--- a/torch/quantization/__init__.py
+++ b/torch/quantization/__init__.py
@@ -26,7 +26,7 @@
'Observer', 'WeightObserver', 'observer', 'default_observer',
'default_weight_observer',
# QConfig
- 'QConfig', 'default_qconfig',
+ 'QConfig', 'default_qconfig', 'default_dynamic_qconfig',
# QAT utilities
'default_qat_qconfig', 'prepare_qat', 'quantize_qat',
# module transformations
diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py
index fff1b04..5c0fee8 100644
--- a/torch/quantization/quantize.py
+++ b/torch/quantization/quantize.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
+import torch
import torch.nn as nn
import torch.nn._intrinsic as nni
import torch.nn._intrinsic.quantized as nniq
@@ -260,7 +261,7 @@
nn.LSTM : default_dynamic_qconfig,
}
-def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING):
+def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING, dtype=torch.qint8):
r"""Converts a float model to dynamic quantized model.
Perform dynamic training and output a quantized model.
@@ -268,7 +269,7 @@
model = copy.deepcopy(model)
model.eval()
propagate_qconfig(model, qconfig_dict)
- convert(model, mapping)
+ convert(model, mapping, dtype)
return model
def prepare_qat(model):
@@ -294,7 +295,7 @@
convert(model)
return model
-def convert(module, mapping=DEFAULT_MODULE_MAPPING):
+def convert(module, mapping=DEFAULT_MODULE_MAPPING, dtype=torch.qint8):
r"""Converts the float module with observers(where we can get quantization
parameters) to a quantized module.
Args:
@@ -311,13 +312,13 @@
for name, mod in module.named_children():
if type(mod) not in SWAPPABLE_MODULES:
- convert(mod, mapping)
- reassign[name] = swap_module(mod, mapping)
+ convert(mod, mapping, dtype)
+ reassign[name] = swap_module(mod, mapping, dtype)
for key, value in reassign.items():
module._modules[key] = value
-def swap_module(mod, mapping):
+def swap_module(mod, mapping, dtype=torch.qint8):
r"""Swaps the module if it has a quantized counterpart and it has an
`observer` attached.
@@ -331,7 +332,14 @@
new_mod = mod
if hasattr(mod, 'qconfig') and mod.qconfig is not None:
if type(mod) in mapping:
- new_mod = mapping[type(mod)].from_float(mod)
+ supported_scalar_types = [torch.qint8, torch.float16]
+ if dtype not in supported_scalar_types:
+ raise RuntimeError('Unsupported dtype: {}'.format(dtype))
+ if dtype == torch.qint8:
+ new_mod = mapping[type(mod)].from_float(mod)
+ elif dtype == torch.float16:
+ # We want to support float16 dynamic quantization
+ new_mod = mapping[type(mod)].from_float(mod, dtype)
return new_mod
def dump_tensor(mod, target_dict, prefix=""):