Add FP16 weight support for LSTM in quantize_dynamic (#25975)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25975

We add FP16 weight support to the dynamically quantized LSTM: `quantize_dynamic` now takes a `dtype` argument (`torch.qint8` by default, or `torch.float16`), and when `dtype=torch.float16` the LSTM weights are packed with `torch.fbgemm_pack_gemm_matrix_fp16` instead of being quantized to int8.
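
In rough terms, callers pick the weight dtype at quantization time. Below is a minimal sketch of the intended usage; the toy wrapper module, the type-keyed qconfig dict, and the module mapping are illustrative assumptions (they mirror the test below but are not copied from it):

```python
import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic, default_dynamic_qconfig

class LSTMModel(nn.Module):  # toy wrapper, for illustration only
    def __init__(self):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(2, 2, num_layers=1)

    def forward(self, x, hiddens):
        return self.lstm(x, hiddens)

model = LSTMModel()
qconfig_dict = {nn.LSTM: default_dynamic_qconfig}       # illustrative qconfig dict
mapping = {nn.LSTM: torch.nn.quantized.dynamic.LSTM}    # dynamic module mapping

# Existing int8 path (now the explicit default).
model_int8 = quantize_dynamic(model, qconfig_dict, mapping, dtype=torch.qint8)

# New fp16 path: inside nn.quantized.dynamic.LSTM the weights are packed via
# torch.fbgemm_pack_gemm_matrix_fp16 instead of being quantized to int8.
model_fp16 = quantize_dynamic(model, qconfig_dict, mapping, dtype=torch.float16)

# Running the converted module requires FBGEMM CPU support, as in the test below.
x = torch.randn(3, 1, 2)                                  # (seq_len, batch, input_size)
h0 = (torch.zeros(1, 1, 2), torch.zeros(1, 1, 2))
y_fp16, hidden_fp16 = model_fp16(x, h0)
```

Under the hood, `quantize_dynamic` threads the `dtype` through `convert` and `swap_module` down to the dynamic module's `from_float`, which decides between int8 prepacking (`torch.ops.quantized.linear_prepack`) and FP16 GEMM packing per layer and direction.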

Test Plan:
buck test mode/dev caffe2/test:quantization -- 'test_quantized_rnn \(test_quantization\.PostTrainingDynamicQuantTest\)'  --print-passing-details

```
[jianyuhuang@devvm794.ftw3.facebook.com: ~/fbsource/fbcode/caffe2/test] $ buck test mode/dev caffe2/test:quantization
-- 'test_quantized_rnn \(test_quantization\.PostTrainingDynamicQuantTest\)'  --print-passing-details
Building: finished in 13.4 sec (100%) 8134/8134 jobs, 81 updated
  Total time: 13.9 sec
Trace available for this run at /tmp/testpilot.20190910-210241.2092790.log
TestPilot test runner for Facebook. See https://fburl.com/testpilot for details.
Testpilot build revision c86e65add357582accb6ec0be23b92c8a2c510bd fbpkg ca46e8f5b26c451a8b0b2462c11bb61d at Mon Sep  9
22:16:37 2019 by twsvcscm from /usr/local/fbprojects/packages/testinfra.testpilot/696/t.par
Discovering tests
Running 1 tests
Started new test run: https://our.intern.facebook.com/intern/testinfra/testrun/1125900050322971
      ✓ caffe2/test:quantization - test_quantized_rnn (test_quantization.PostTrainingDynamicQuantTest) 0.183 1/1 (passed)
Test output:
> test_quantized_rnn (test_quantization.PostTrainingDynamicQuantTest) ... ok
>
> ----------------------------------------------------------------------
> Ran 1 test in 0.184s
>
> OK
Finished test run: https://our.intern.facebook.com/intern/testinfra/testrun/1125900050322971
Summary (total time 4.35s):
  PASS: 1
  FAIL: 0
  SKIP: 0
  FATAL: 0
  TIMEOUT: 0
  OMIT: 0
```

Differential Revision: D17299116

fbshipit-source-id: 7fe91ece25867f2c0496f1b63fb1041e6b815166
diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp
index b0ae2df..1b01f64 100644
--- a/aten/src/ATen/native/RNN.cpp
+++ b/aten/src/ATen/native/RNN.cpp
@@ -1072,11 +1072,13 @@
   check_device(_input, _params, hx);
   auto input = batch_first ? _input.transpose(0, 1) : _input;
   TORCH_CHECK(has_biases, "quantized LSTM requires biases");
-  TORCH_CHECK(result_dtype == at::kChar || result_dtype == at::kHalf,
-              "dtype is not supported");
+  TORCH_CHECK(
+      result_dtype == at::kChar || result_dtype == at::kQInt8 ||
+          result_dtype == at::kHalf,
+      "dtype is not supported");
 
   std::tuple<Tensor, Tensor, Tensor> results;
-  if (result_dtype == at::kChar) {
+  if (result_dtype == at::kChar || result_dtype == at::kQInt8) {
     if (use_dynamic) {
       auto params = gather_quantized_params_dynamic(_params);
       results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 9f5bdcf..bad08b0 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -492,12 +492,20 @@
             torch.nn.LSTM: torch.nn.quantized.dynamic.LSTM,
         }
         model_int8 = quantize_dynamic(
-            model, qconfig_dynamic_dict, default_dynamic_module_mapping
+            model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping,
+            dtype=torch.qint8
+        )
+        model_fp16 = quantize_dynamic(
+            model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping,
+            dtype=torch.float16
         )
         cell_int8 = model_int8.lstm
+        cell_fp16 = model_fp16.lstm
 
         assert type(cell_int8) == torch.nn.quantized.dynamic.LSTM, \
             'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic'
+        assert type(cell_fp16) == torch.nn.quantized.dynamic.LSTM, \
+            'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic'
 
         niter = 10
         x = torch.tensor([[100, -155],
@@ -549,6 +557,13 @@
         for loaded_val, ref_val in zip(out_loaded, ref_out):
             torch.testing.assert_allclose(loaded_val, ref_val)
 
+        # Compare fp16 quantized to unquantized
+        output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)
+
+        torch.testing.assert_allclose(output_fp16, ref_out)
+        self.assertEqual(output_fp16, ref_out)
+        for out, ref in zip(final_hiddens_fp16, ref_hid):
+            torch.testing.assert_allclose(out, ref)
 
 @unittest.skipIf(
     not torch.fbgemm_is_cpu_supported(),
diff --git a/torch/nn/quantized/dynamic/modules/rnn.py b/torch/nn/quantized/dynamic/modules/rnn.py
index 66e6504..8af7026 100644
--- a/torch/nn/quantized/dynamic/modules/rnn.py
+++ b/torch/nn/quantized/dynamic/modules/rnn.py
@@ -20,7 +20,7 @@
 
     def __init__(self, mode, input_size, hidden_size,
                  num_layers=1, bias=True, batch_first=False,
-                 dropout=0., bidirectional=False):
+                 dropout=0., bidirectional=False, dtype=torch.qint8):
         super(RNNBase, self).__init__()
 
         self.mode = mode
@@ -31,6 +31,7 @@
         self.batch_first = batch_first
         self.dropout = float(dropout)
         self.bidirectional = bidirectional
+        self.dtype = dtype
         num_directions = 2 if bidirectional else 1
 
         if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
@@ -51,48 +52,67 @@
 
         self._all_weight_names = []
         self._all_weight_values = []
-
         for layer in range(num_layers):
             for direction in range(num_directions):
                 layer_input_size = input_size if layer == 0 else hidden_size * num_directions
 
-                def process_weights(ihhh, layer, suffix, qweight, bias):
-                    # for each layer, for each direction we need to quantize and pack
-                    # weights and pack parameters in this order:
-                    #
-                    #   w_ih, w_hh
-                    packed_weight = \
-                        torch.ops.quantized.linear_prepack(qweight, bias)
-                    params = [packed_weight]
-                    pos_names = ['w']
-                    ret_name = ['{}_{}_l{}{}'.format(
-                        name, ihhh, layer, suffix) for name in pos_names]
-                    return params, ret_name
+                def process_weights(ihhh, layer, suffix, qweight, bias, dtype):
+                    if dtype == torch.qint8:
+                        # for each layer, for each direction we need to quantize and pack
+                        # weights and pack parameters in this order:
+                        #
+                        #   w_ih, w_hh
+                        packed_weight = \
+                            torch.ops.quantized.linear_prepack(qweight, bias)
 
-                w_ih = torch._empty_affine_quantized(
-                    [gate_size, layer_input_size], scale=1, zero_point=0, dtype=torch.qint8)
-                w_hh = torch._empty_affine_quantized(
-                    [gate_size, hidden_size], scale=1, zero_point=0, dtype=torch.qint8)
-                b_ih = torch._empty_affine_quantized(
-                    [gate_size], scale=1, zero_point=0, dtype=torch.qint32)
-                # Second bias vector included for CuDNN compatibility. Only one
-                # bias vector is needed in standard definition.
-                b_hh = torch._empty_affine_quantized(
-                    [gate_size], scale=1, zero_point=0, dtype=torch.qint32)
+                        params = [packed_weight]
+                        pos_names = ['w']
+                        ret_name = ['{}_{}_l{}{}'.format(
+                            name, ihhh, layer, suffix) for name in pos_names]
+                        return params, ret_name
+                    else:
+                        # for each layer, for each direction we need to quantize and pack
+                        # weights and pack parameters in this order:
+                        #
+                        #   packed_ih, packed_hh, b_ih, b_hh
+                        packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
+                            qweight)
+
+                        params = [packed_weight, bias]
+                        pos_names = ['packed', 'b']
+                        ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
+                        return params, ret_name
+
+                if dtype == torch.qint8:
+                    w_ih = torch._empty_affine_quantized(
+                        [gate_size, layer_input_size], scale=1, zero_point=0, dtype=torch.qint8)
+                    w_hh = torch._empty_affine_quantized(
+                        [gate_size, hidden_size], scale=1, zero_point=0, dtype=torch.qint8)
+                    b_ih = torch._empty_affine_quantized(
+                        [gate_size], scale=1, zero_point=0, dtype=torch.qint32)
+                    # Second bias vector included for CuDNN compatibility. Only one
+                    # bias vector is needed in standard definition.
+                    b_hh = torch._empty_affine_quantized(
+                        [gate_size], scale=1, zero_point=0, dtype=torch.qint32)
+
+                else:
+                    w_ih = torch.Tensor(gate_size, layer_input_size).float()
+                    w_hh = torch.Tensor(gate_size, hidden_size).float()
+                    b_ih = torch.Tensor(gate_size).float()
+                    # Second bias vector included for CuDNN compatibility. Only one
+                    # bias vector is needed in standard definition.
+                    b_hh = torch.Tensor(gate_size).float()
 
                 suffix = '_reverse' if direction == 1 else ''
                 ih_params, ih_param_names = process_weights(
-                    'ih', layer, suffix, w_ih, b_ih)
+                    'ih', layer, suffix, w_ih, b_ih, dtype)
                 hh_params, hh_param_names = process_weights(
-                    'hh', layer, suffix, w_hh, b_hh)
+                    'hh', layer, suffix, w_hh, b_hh, dtype)
 
                 for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)):
-
                     self._all_weight_names.extend([ih_name, hh_name])
                     self._all_weight_values.extend([ih, hh])
 
-
-
     def check_input(self, input, batch_sizes):
         # type: (Tensor, Optional[Tensor]) -> None
         expected_input_dim = 2 if batch_sizes is not None else 3
@@ -150,6 +170,7 @@
             self._all_weight_names,
             self.__overloads__,
             self.training,
+            self.dtype,
         )
 
         dynamic_vals = torch.jit.annotate(List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
@@ -173,29 +194,37 @@
         self._all_weight_names = vals[8]
         self.__overloads__ = vals[9]
         self.training = vals[10]
+        self.dtype = vals[11]
 
         self._all_weight_values = []
         for i in range(len(self._all_weight_names)):
             self._all_weight_values.append(torch.ops.quantized.linear_prepack(*dynamic_vals[i]))
 
     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, dtype=torch.qint8):
         assert type(mod) == torch.nn.LSTM, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM'
         assert hasattr(
             mod, 'qconfig'), 'Input float module must have qconfig defined'
-        if mod.qconfig is not None and mod.qconfig.weight() is not None:
-            weight_observer = mod.qconfig.weight()
-        else:
-            # We have the circular import issues if we import the qconfig in the beginning of this file:
-            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
-            # import until we need it.
-            from torch.quantization.QConfig import default_dynamic_qconfig
-            weight_observer = default_dynamic_qconfig.weight()
-        assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
+
+        supported_scalar_types = [torch.qint8, torch.float16]
+        if dtype not in supported_scalar_types:
+            raise RuntimeError('Unsupported dtype: {}'.format(dtype))
+
+        # When dtype = torch.float16, we don't need weight_observer
+        if dtype == torch.qint8:
+            if mod.qconfig is not None and mod.qconfig.weight() is not None:
+                weight_observer = mod.qconfig.weight()
+            else:
+                # We have the circular import issues if we import the qconfig in the beginning of this file:
+                # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
+                # import until we need it.
+                from torch.quantization.QConfig import default_dynamic_qconfig
+                weight_observer = default_dynamic_qconfig.weight()
+            assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
 
         if mod.mode == 'LSTM':
             qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
-                            mod.bias, mod.batch_first, mod.dropout, mod.bidirectional)
+                            mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
 
         num_directions = 2 if mod.bidirectional else 1
 
@@ -211,38 +240,51 @@
             for direction in range(num_directions):
                 layer_input_size = qRNNBase.input_size if layer == 0 else qRNNBase.hidden_size * num_directions
 
-                def process_weights(ihhh, layer, suffix):
+                def process_weights(ihhh, layer, suffix, dtype):
                     weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
                     bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)
 
                     weight = getattr(mod, weight_name)
                     bias = getattr(mod, bias_name)
-                    # for each layer, for each direction we need to quantize and pack
-                    # weights and pack parameters in this order:
-                    #
-                    #   w_ih, w_hh
-                    weight_observer(weight)
-                    wt_scale, wt_zp = weight_observer.calculate_qparams()
-                    qweight = torch.quantize_linear(
-                        weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
-                    packed_weight = \
-                        torch.ops.quantized.linear_prepack(qweight, bias)
 
-                    params = [packed_weight]
-                    pos_names = ['w']
-                    ret_name = ['{}_{}_l{}{}'.format(
-                        name, ihhh, layer, suffix) for name in pos_names]
-                    return params, ret_name
+                    if dtype == torch.qint8:
+                        # for each layer, for each direction we need to quantize and pack
+                        # weights and pack parameters in this order:
+                        #
+                        #   w_ih, w_hh
+                        weight_observer(weight)
+                        wt_scale, wt_zp = weight_observer.calculate_qparams()
+                        qweight = torch.quantize_linear(
+                            weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
+                        packed_weight = \
+                            torch.ops.quantized.linear_prepack(qweight, bias)
+
+                        params = [packed_weight]
+                        pos_names = ['w']
+                        ret_name = ['{}_{}_l{}{}'.format(
+                            name, ihhh, layer, suffix) for name in pos_names]
+                        return params, ret_name
+                    else:
+                        # for each layer, for each direction we need to quantize and pack
+                        # weights and pack parameters in this order:
+                        #
+                        #   packed_ih, packed_hh, b_ih, b_hh
+                        packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
+                            weight.float())
+
+                        params = [packed_weight, bias]
+                        pos_names = ['packed', 'b']
+                        ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
+                        return params, ret_name
 
                 suffix = '_reverse' if direction == 1 else ''
-                ih_params, ih_param_names = process_weights('ih', layer, suffix)
-                hh_params, hh_param_names = process_weights('hh', layer, suffix)
+                ih_params, ih_param_names = process_weights('ih', layer, suffix, dtype)
+                hh_params, hh_param_names = process_weights('hh', layer, suffix, dtype)
 
                 for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)):
                     qRNNBase._all_weight_names.extend([ih_name, hh_name])
                     qRNNBase._all_weight_values.extend([ih, hh])
 
-
         return qRNNBase
 
 
@@ -273,7 +315,7 @@
 
         result = _VF.quantized_lstm(input, hx, self._all_weight_values, self.bias, self.num_layers,
                                     float(self.dropout), self.training, self.bidirectional,
-                                    self.batch_first, dtype=torch.int8, use_dynamic=True)
+                                    self.batch_first, dtype=self.dtype, use_dynamic=True)
         output = result[0]
         hidden = result[1:]
 
@@ -330,5 +372,5 @@
             return self.forward_tensor(input, hx)
 
     @classmethod
-    def from_float(cls, mod):
-        return super(LSTM, cls).from_float(mod)
+    def from_float(cls, mod, dtype=torch.qint8):
+        return super(LSTM, cls).from_float(mod, dtype)
diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py
index 8b6a8d6..4c60294 100644
--- a/torch/quantization/__init__.py
+++ b/torch/quantization/__init__.py
@@ -26,7 +26,7 @@
     'Observer', 'WeightObserver', 'observer', 'default_observer',
     'default_weight_observer',
     # QConfig
-    'QConfig', 'default_qconfig',
+    'QConfig', 'default_qconfig', 'default_dynamic_qconfig',
     # QAT utilities
     'default_qat_qconfig', 'prepare_qat', 'quantize_qat',
     # module transformations
diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py
index fff1b04..5c0fee8 100644
--- a/torch/quantization/quantize.py
+++ b/torch/quantization/quantize.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import copy
+import torch
 import torch.nn as nn
 import torch.nn._intrinsic as nni
 import torch.nn._intrinsic.quantized as nniq
@@ -260,7 +261,7 @@
     nn.LSTM : default_dynamic_qconfig,
 }
 
-def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING):
+def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING, dtype=torch.qint8):
     r"""Converts a float model to dynamic quantized model.
 
     Perform dynamic training and output a quantized model.
@@ -268,7 +269,7 @@
     model = copy.deepcopy(model)
     model.eval()
     propagate_qconfig(model, qconfig_dict)
-    convert(model, mapping)
+    convert(model, mapping, dtype)
     return model
 
 def prepare_qat(model):
@@ -294,7 +295,7 @@
     convert(model)
     return model
 
-def convert(module, mapping=DEFAULT_MODULE_MAPPING):
+def convert(module, mapping=DEFAULT_MODULE_MAPPING, dtype=torch.qint8):
     r"""Converts the float module with observers(where we can get quantization
     parameters) to a quantized module.
     Args:
@@ -311,13 +312,13 @@
 
     for name, mod in module.named_children():
         if type(mod) not in SWAPPABLE_MODULES:
-            convert(mod, mapping)
-        reassign[name] = swap_module(mod, mapping)
+            convert(mod, mapping, dtype)
+        reassign[name] = swap_module(mod, mapping, dtype)
 
     for key, value in reassign.items():
         module._modules[key] = value
 
-def swap_module(mod, mapping):
+def swap_module(mod, mapping, dtype=torch.qint8):
     r"""Swaps the module if it has a quantized counterpart and it has an
     `observer` attached.
 
@@ -331,7 +332,14 @@
     new_mod = mod
     if hasattr(mod, 'qconfig') and mod.qconfig is not None:
         if type(mod) in mapping:
-            new_mod = mapping[type(mod)].from_float(mod)
+            supported_scalar_types = [torch.qint8, torch.float16]
+            if dtype not in supported_scalar_types:
+                raise RuntimeError('Unsupported dtype: {}'.format(dtype))
+            if dtype == torch.qint8:
+                new_mod = mapping[type(mod)].from_float(mod)
+            elif dtype == torch.float16:
+                # We want to support float16 dynamic quantization
+                new_mod = mapping[type(mod)].from_float(mod, dtype)
     return new_mod
 
 def dump_tensor(mod, target_dict, prefix=""):