| import numbers |
| from typing import Optional, Tuple |
| import warnings |
| |
| import torch |
| from torch import Tensor |
| |
| """ |
We will recreate all the RNN modules, because we need the modules to be
decomposed into their building blocks so that observers can be attached to them.
| """ |
| |
| class LSTMCell(torch.nn.Module): |
| r"""A quantizable long short-term memory (LSTM) cell. |
| |
    For the description and the argument types, please refer to :class:`~torch.nn.LSTMCell`.
| |
| Examples:: |
| |
| >>> import torch.nn.quantizable as nnqa |
| >>> rnn = nnqa.LSTMCell(10, 20) |
        >>> input = torch.randn(6, 3, 10)
| >>> hx = torch.randn(3, 20) |
| >>> cx = torch.randn(3, 20) |
| >>> output = [] |
| >>> for i in range(6): |
| ... hx, cx = rnn(input[i], (hx, cx)) |
| ... output.append(hx) |
| """ |
| _FLOAT_MODULE = torch.nn.LSTMCell |
| |
| def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True, |
| device=None, dtype=None) -> None: |
| factory_kwargs = {'device': device, 'dtype': dtype} |
| super().__init__() |
| self.input_size = input_dim |
| self.hidden_size = hidden_dim |
| self.bias = bias |
| |
| self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs) |
| self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs) |
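        # FloatFunctional stands in for the add/mul ops below so that
        # observers (and, after conversion, quantized kernels) can attach to them.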
| self.gates = torch.nn.quantized.FloatFunctional() |
| |
| self.fgate_cx = torch.nn.quantized.FloatFunctional() |
| self.igate_cgate = torch.nn.quantized.FloatFunctional() |
| self.fgate_cx_igate_cgate = torch.nn.quantized.FloatFunctional() |
| |
| self.ogate_cy = torch.nn.quantized.FloatFunctional() |
| |
| def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: |
| if hidden is None or hidden[0] is None or hidden[1] is None: |
| hidden = self.initialize_hidden(x.shape[0], x.is_quantized) |
| hx, cx = hidden |
| |
| igates = self.igates(x) |
| hgates = self.hgates(hx) |
| gates = self.gates.add(igates, hgates) |
| |
| input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) |
| |
| input_gate = torch.sigmoid(input_gate) |
| forget_gate = torch.sigmoid(forget_gate) |
| cell_gate = torch.tanh(cell_gate) |
| out_gate = torch.sigmoid(out_gate) |
| |
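        # Standard LSTM cell update, with each mul/add routed through its own
        # FloatFunctional so the intermediate values stay observable:
        #   c_t = f_t * c_{t-1} + i_t * g_t
        #   h_t = o_t * tanh(c_t)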
| fgate_cx = self.fgate_cx.mul(forget_gate, cx) |
| igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) |
| fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate) |
| cy = fgate_cx_igate_cgate |
| |
| tanh_cy = torch.tanh(cy) |
| hy = self.ogate_cy.mul(out_gate, tanh_cy) |
| return hy, cy |
| |
| def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]: |
| h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size)) |
| if is_quantized: |
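            # Trivial affine params (scale=1.0, zero_point=0) represent the
            # all-zero initial state exactly.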
| h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8) |
| c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8) |
| return h, c |
| |
| def _get_name(self): |
| return 'QuantizableLSTMCell' |
| |
| @classmethod |
| def from_params(cls, wi, wh, bi=None, bh=None): |
| """Uses the weights and biases to create a new LSTM cell. |
| |
| Args: |
| wi, wh: Weights for the input and hidden layers |
| bi, bh: Biases for the input and hidden layers |
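
        Example (a minimal sketch with hypothetical shapes; following
        :class:`~torch.nn.LSTMCell`, ``wi`` is ``(4 * hidden_size, input_size)``
        and ``wh`` is ``(4 * hidden_size, hidden_size)``)::

            >>> wi, wh = torch.randn(80, 10), torch.randn(80, 20)
            >>> bi, bh = torch.randn(80), torch.randn(80)
            >>> cell = LSTMCell.from_params(wi, wh, bi, bh)  # input_size=10, hidden_size=20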
| """ |
| assert (bi is None) == (bh is None) # Either both None or both have values |
| input_size = wi.shape[1] |
| hidden_size = wh.shape[1] |
| cell = cls(input_dim=input_size, hidden_dim=hidden_size, |
| bias=(bi is not None)) |
| cell.igates.weight = torch.nn.Parameter(wi) |
| if bi is not None: |
| cell.igates.bias = torch.nn.Parameter(bi) |
| cell.hgates.weight = torch.nn.Parameter(wh) |
| if bh is not None: |
| cell.hgates.bias = torch.nn.Parameter(bh) |
| return cell |
| |
| @classmethod |
| def from_float(cls, other): |
| assert type(other) == cls._FLOAT_MODULE |
| assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'" |
| observed = cls.from_params(other.weight_ih, other.weight_hh, |
| other.bias_ih, other.bias_hh) |
| observed.qconfig = other.qconfig |
| observed.igates.qconfig = other.qconfig |
| observed.hgates.qconfig = other.qconfig |
| return observed |
| |
| |
| class _LSTMSingleLayer(torch.nn.Module): |
| r"""A single one-directional LSTM layer. |
| |
    The difference between a layer and a cell is that the layer can process a
    whole sequence, while the cell expects a single time step.
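
    Example (a minimal sketch, assuming a sequence-first input of shape
    ``(seq_len, batch, input_dim)``)::

        >>> layer = _LSTMSingleLayer(10, 20)
        >>> seq = torch.randn(5, 3, 10)
        >>> out, (h, c) = layer(seq)  # out: (5, 3, 20); h and c: (3, 20)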
| """ |
| def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True, |
| device=None, dtype=None) -> None: |
| factory_kwargs = {'device': device, 'dtype': dtype} |
| super().__init__() |
| self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs) |
| |
| def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): |
| result = [] |
| for xx in x: |
| hidden = self.cell(xx, hidden) |
| result.append(hidden[0]) # type: ignore[index] |
| result_tensor = torch.stack(result, 0) |
| return result_tensor, hidden |
| |
| @classmethod |
| def from_params(cls, *args, **kwargs): |
| cell = LSTMCell.from_params(*args, **kwargs) |
| layer = cls(cell.input_size, cell.hidden_size, cell.bias) |
| layer.cell = cell |
| return layer |
| |
| |
| class _LSTMLayer(torch.nn.Module): |
| r"""A single bi-directional LSTM layer.""" |
| def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True, |
| batch_first: bool = False, bidirectional: bool = False, |
| device=None, dtype=None) -> None: |
| factory_kwargs = {'device': device, 'dtype': dtype} |
| super().__init__() |
| self.batch_first = batch_first |
| self.bidirectional = bidirectional |
| self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs) |
| if self.bidirectional: |
| self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs) |
| |
| def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): |
| if self.batch_first: |
| x = x.transpose(0, 1) |
| if hidden is None: |
| hx_fw, cx_fw = (None, None) |
| else: |
| hx_fw, cx_fw = hidden |
| hidden_bw: Optional[Tuple[Tensor, Tensor]] = None |
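        # In the bidirectional case the incoming state stacks both directions
        # along dim 0; split it into separate forward/backward (hx, cx) pairs.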
| if self.bidirectional: |
| if hx_fw is None: |
| hx_bw = None |
| else: |
| hx_bw = hx_fw[1] |
| hx_fw = hx_fw[0] |
| if cx_fw is None: |
| cx_bw = None |
| else: |
| cx_bw = cx_fw[1] |
| cx_fw = cx_fw[0] |
| if hx_bw is not None and cx_bw is not None: |
| hidden_bw = hx_bw, cx_bw |
| if hx_fw is None and cx_fw is None: |
| hidden_fw = None |
| else: |
| hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw) |
| result_fw, hidden_fw = self.layer_fw(x, hidden_fw) |
| |
| if hasattr(self, 'layer_bw') and self.bidirectional: |
| x_reversed = x.flip(0) |
| result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw) |
| result_bw = result_bw.flip(0) |
| |
| result = torch.cat([result_fw, result_bw], result_fw.dim() - 1) |
| if hidden_fw is None and hidden_bw is None: |
| h = None |
| c = None |
| elif hidden_fw is None: |
| (h, c) = torch.jit._unwrap_optional(hidden_bw) |
| elif hidden_bw is None: |
| (h, c) = torch.jit._unwrap_optional(hidden_fw) |
| else: |
| h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item] |
| c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item] |
| else: |
| result = result_fw |
| h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment] |
| |
| if self.batch_first: |
| result.transpose_(0, 1) |
| |
| return result, (h, c) |
| |
| @classmethod |
| def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs): |
| r""" |
| There is no FP equivalent of this class. This function is here just to |
| mimic the behavior of the `prepare` within the `torch.ao.quantization` |
| flow. |
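
        Example (a minimal sketch with a hypothetical float module;
        ``default_qconfig`` comes from ``torch.ao.quantization``)::

            >>> float_lstm = torch.nn.LSTM(10, 20, num_layers=2)
            >>> qconfig = torch.ao.quantization.default_qconfig
            >>> layer0 = _LSTMLayer.from_float(float_lstm, layer_idx=0, qconfig=qconfig)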
| """ |
| assert hasattr(other, 'qconfig') or (qconfig is not None) |
| |
| input_size = kwargs.get('input_size', other.input_size) |
| hidden_size = kwargs.get('hidden_size', other.hidden_size) |
| bias = kwargs.get('bias', other.bias) |
| batch_first = kwargs.get('batch_first', other.batch_first) |
| bidirectional = kwargs.get('bidirectional', other.bidirectional) |
| |
| layer = cls(input_size, hidden_size, bias, batch_first, bidirectional) |
| layer.qconfig = getattr(other, 'qconfig', qconfig) |
| wi = getattr(other, f'weight_ih_l{layer_idx}') |
| wh = getattr(other, f'weight_hh_l{layer_idx}') |
| bi = getattr(other, f'bias_ih_l{layer_idx}', None) |
| bh = getattr(other, f'bias_hh_l{layer_idx}', None) |
| |
| layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh) |
| |
| if other.bidirectional: |
| wi = getattr(other, f'weight_ih_l{layer_idx}_reverse') |
| wh = getattr(other, f'weight_hh_l{layer_idx}_reverse') |
| bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None) |
| bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None) |
| layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh) |
| return layer |
| |
| |
| class LSTM(torch.nn.Module): |
| r"""A quantizable long short-term memory (LSTM). |
| |
    For the description and the argument types, please refer to :class:`~torch.nn.LSTM`.
| |
| Attributes: |
| layers : instances of the `_LSTMLayer` |
| |
| .. note:: |
| To access the weights and biases, you need to access them per layer. |
| See examples below. |
| |
| Examples:: |
| |
| >>> import torch.nn.quantizable as nnqa |
| >>> rnn = nnqa.LSTM(10, 20, 2) |
| >>> input = torch.randn(5, 3, 10) |
| >>> h0 = torch.randn(2, 3, 20) |
| >>> c0 = torch.randn(2, 3, 20) |
| >>> output, (hn, cn) = rnn(input, (h0, c0)) |
        >>> # To get the weights:
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].layer_fw.cell.igates.weight)
        tensor([[...]])
        >>> print(rnn.layers[0].layer_fw.cell.hgates.weight)
        tensor([[...]])
| """ |
| _FLOAT_MODULE = torch.nn.LSTM |
| |
| def __init__(self, input_size: int, hidden_size: int, |
| num_layers: int = 1, bias: bool = True, |
| batch_first: bool = False, dropout: float = 0., |
| bidirectional: bool = False, |
| device=None, dtype=None) -> None: |
| factory_kwargs = {'device': device, 'dtype': dtype} |
| super().__init__() |
| self.input_size = input_size |
| self.hidden_size = hidden_size |
| self.num_layers = num_layers |
| self.bias = bias |
| self.batch_first = batch_first |
| self.dropout = float(dropout) |
| self.bidirectional = bidirectional |
| self.training = False # We don't want to train using this module |
| num_directions = 2 if bidirectional else 1 |
| |
| if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \ |
| isinstance(dropout, bool): |
| raise ValueError("dropout should be a number in range [0, 1] " |
| "representing the probability of an element being " |
| "zeroed") |
| if dropout > 0: |
| warnings.warn("dropout option for quantizable LSTM is ignored. " |
| "If you are training, please, use nn.LSTM version " |
| "followed by `prepare` step.") |
| if num_layers == 1: |
| warnings.warn("dropout option adds dropout after all but last " |
| "recurrent layer, so non-zero dropout expects " |
| "num_layers greater than 1, but got dropout={} " |
| "and num_layers={}".format(dropout, num_layers)) |
| |
| layers = [_LSTMLayer(self.input_size, self.hidden_size, |
| self.bias, batch_first=False, |
| bidirectional=self.bidirectional, **factory_kwargs)] |
| for layer in range(1, num_layers): |
| layers.append(_LSTMLayer(self.hidden_size, self.hidden_size, |
| self.bias, batch_first=False, |
| bidirectional=self.bidirectional, |
| **factory_kwargs)) |
| self.layers = torch.nn.ModuleList(layers) |
| |
| def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): |
| if self.batch_first: |
| x = x.transpose(0, 1) |
| |
| max_batch_size = x.size(1) |
| num_directions = 2 if self.bidirectional else 1 |
| if hidden is None: |
| zeros = torch.zeros(num_directions, max_batch_size, |
| self.hidden_size, dtype=torch.float, |
| device=x.device) |
| zeros.squeeze_(0) |
| if x.is_quantized: |
| zeros = torch.quantize_per_tensor(zeros, scale=1.0, |
| zero_point=0, dtype=x.dtype) |
| hxcx = [(zeros, zeros) for _ in range(self.num_layers)] |
| else: |
| hidden_non_opt = torch.jit._unwrap_optional(hidden) |
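            # A user-supplied hidden state arrives packed as
            # (num_layers * num_directions, batch, hidden_size); unpack it into
            # one (hx, cx) pair per layer.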
| if isinstance(hidden_non_opt[0], Tensor): |
| hx = hidden_non_opt[0].reshape(self.num_layers, num_directions, |
| max_batch_size, |
| self.hidden_size).unbind(0) |
| cx = hidden_non_opt[1].reshape(self.num_layers, num_directions, |
| max_batch_size, |
| self.hidden_size).unbind(0) |
| hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)] |
| else: |
| hxcx = hidden_non_opt |
| |
| hx_list = [] |
| cx_list = [] |
| for idx, layer in enumerate(self.layers): |
| x, (h, c) = layer(x, hxcx[idx]) |
| hx_list.append(torch.jit._unwrap_optional(h)) |
| cx_list.append(torch.jit._unwrap_optional(c)) |
| hx_tensor = torch.stack(hx_list) |
| cx_tensor = torch.stack(cx_list) |
| |
        # The bidirectional case adds an extra direction dimension;
        # collapse it into the (num_layers * num_directions) leading dimension.
| hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1]) |
| cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1]) |
| |
| if self.batch_first: |
| x = x.transpose(0, 1) |
| |
| return x, (hx_tensor, cx_tensor) |
| |
| def _get_name(self): |
| return 'QuantizableLSTM' |
| |
| @classmethod |
| def from_float(cls, other, qconfig=None): |
| assert isinstance(other, cls._FLOAT_MODULE) |
| assert (hasattr(other, 'qconfig') or qconfig) |
| observed = cls(other.input_size, other.hidden_size, other.num_layers, |
| other.bias, other.batch_first, other.dropout, |
| other.bidirectional) |
| observed.qconfig = getattr(other, 'qconfig', qconfig) |
| for idx in range(other.num_layers): |
| observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig, |
| batch_first=False) |
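        # Unlike a typical `from_float`, this also runs the `prepare` step, so
        # the returned module already carries observers and is meant for eval only.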
| observed.eval() |
| observed = torch.ao.quantization.prepare(observed, inplace=True) |
| return observed |
| |
| @classmethod |
| def from_observed(cls, other): |
| # The whole flow is float -> observed -> quantized |
| # This class does float -> observed only |
        raise NotImplementedError("It looks like you are trying to convert a "
                                  "non-quantizable LSTM module. Please see "
                                  "the examples on quantizable LSTMs.")