| import warnings |
| |
| from typing import List, Optional, Tuple |
| |
| import torch |
| from torch import _VF, Tensor # noqa: F401 |
| from torch.nn.utils.rnn import PackedSequence |
| |
| |
class QuantizedLinear(torch.jit.ScriptModule):
    """Deprecated scripted linear layer that stores an fbgemm-quantized weight.

    The constructor takes an existing module ``other`` (it must expose
    ``in_features``, ``out_features``, ``weight`` and a non-``None`` ``bias``),
    quantizes a float copy of its weight with
    ``torch.fbgemm_linear_quantize_weight`` and keeps a packed form of the
    quantized weight in the ``packed_tensor_ptr`` buffer.
    """

    __constants__ = ["scale", "zero_point"]

    def __init__(self, other):
        super().__init__()
        warnings.warn(
            "torch.jit.QuantizedLinear is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead."
        )

        self.in_features = other.in_features
        self.out_features = other.out_features

        # Quantize a contiguous float copy of the weight; the original weight
        # tensor is not retained on this module.
        float_weight = other.weight.clone(memory_format=torch.contiguous_format).float()
        quantized = torch.fbgemm_linear_quantize_weight(float_weight)
        q_weight, col_offsets, self.scale, self.zero_point = quantized
        self.weight = torch.nn.Parameter(q_weight, requires_grad=False)
        self.col_offsets = torch.nn.Parameter(col_offsets, requires_grad=False)

        assert other.bias is not None, "QuantizedLinear requires a bias"
        float_bias = other.bias.clone(memory_format=torch.contiguous_format).float()
        self.bias = torch.nn.Parameter(float_bias, requires_grad=False)

        # The packed matrix is built from a separate contiguous copy of the
        # quantized weight and stored as a buffer.
        packed = torch.fbgemm_pack_quantized_matrix(
            self.weight.clone(memory_format=torch.contiguous_format)
        )
        self.register_buffer("packed_tensor_ptr", packed)

    @torch.jit.script_method
    def _unpack(self):
        """Rebuild ``packed_tensor_ptr`` in place from the quantized weight."""
        self.packed_tensor_ptr.set_(torch.fbgemm_pack_quantized_matrix(self.weight))

    @torch.jit.script_method
    def _pack(self):
        """Replace ``packed_tensor_ptr`` in place with a zero-dim uint8 placeholder."""
        self.packed_tensor_ptr.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
        )

    @torch.jit.script_method
    def forward(self, input):
        """Apply the quantized linear op to ``input`` and cast back to its dtype."""
        # The kernel consumes float activations; the result is converted back
        # to the dtype the caller passed in.
        result = torch.fbgemm_linear_int8_weight_fp32_activation(
            input.float(),
            self.weight,
            self.packed_tensor_ptr,
            self.col_offsets,
            self.scale,
            self.zero_point,
            self.bias,
        )
        return result.to(input.dtype)

    def extra_repr(self):
        """Return the feature sizes and quantization parameters for ``repr``."""
        template = (
            "in_features={in_features}, out_features={out_features}, "
            "scale={scale}, zero_point={zero_point}"
        )
        return template.format(**self.__dict__)
| |
| |
| # FP16 weights |
class QuantizedLinearFP16(torch.jit.ScriptModule):
    """Deprecated scripted linear layer backed by an fbgemm fp16-packed weight.

    The constructor takes an existing module ``other`` (it must expose
    ``in_features``, ``out_features``, ``weight`` and a non-``None`` ``bias``).
    The original weight is kept as ``original_weight`` so the packed form can
    be rebuilt by ``_unpack``.
    """

    def __init__(self, other):
        super().__init__()
        warnings.warn(
            "torch.jit.QuantizedLinearFP16 is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead."
        )
        self.in_features = other.in_features
        self.out_features = other.out_features
        self.original_weight = other.weight

        # Pack a contiguous float copy of the weight; the same packed tensor is
        # exposed both as ``weight`` and as the ``packed_weight`` buffer.
        packed = torch.fbgemm_pack_gemm_matrix_fp16(
            other.weight.clone(memory_format=torch.contiguous_format).float()
        )
        self.weight = packed

        assert other.bias is not None, "QuantizedLinearFP16 requires a bias"
        float_bias = other.bias.clone(memory_format=torch.contiguous_format).float()
        self.bias = torch.nn.Parameter(float_bias, requires_grad=False)
        self.register_buffer("packed_weight", packed)

    @torch.jit.script_method
    def _unpack(self):
        """Rebuild ``packed_weight`` in place from ``original_weight``."""
        self.packed_weight.set_(
            torch.fbgemm_pack_gemm_matrix_fp16(self.original_weight)
        )

    @torch.jit.script_method
    def _pack(self):
        """Replace ``packed_weight`` in place with a zero-dim uint8 placeholder."""
        self.packed_weight.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
        )

    @torch.jit.script_method
    def forward(self, input):
        """Apply the fp16-weight linear op to the float-converted ``input``."""
        # Unlike the int8 variant, the result is returned as produced by the
        # kernel (no cast back to ``input.dtype``).
        result = torch.fbgemm_linear_fp16_weight_fp32_activation(
            input.float(), self.packed_weight, self.bias
        )
        return result

    def extra_repr(self):
        """Return the feature sizes for ``repr`` (string ends with ``", "``)."""
        template = "in_features={in_features}, out_features={out_features}, "
        return template.format(**self.__dict__)
| |
| |
| # Quantized RNN cell implementations |
class QuantizedRNNCellBase(torch.jit.ScriptModule):
    """Deprecated shared base for the scripted quantized RNN cell modules.

    The constructor quantizes ``other.weight_ih`` and ``other.weight_hh`` with
    ``torch.fbgemm_linear_quantize_weight``, registers the quantized weights,
    their column offsets and their packed forms as buffers, and stores float
    copies of ``other.bias_ih`` / ``other.bias_hh`` as frozen parameters.
    """

    __constants__ = [
        "input_size",
        "hidden_size",
        "bias",
        "scale_hh",
        "scale_ih",
        "zero_point_ih",
        "zero_point_hh",
    ]

    def __init__(self, other):
        super().__init__()
        warnings.warn(
            "torch.jit.QuantizedRNNCellBase is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
        )

        self.input_size = other.input_size
        self.hidden_size = other.hidden_size
        self.bias = other.bias
        if not self.bias:
            raise ValueError("Quantized RNN cells require bias terms")

        # Input-hidden ("ih") first, then hidden-hidden ("hh"): quantize a
        # contiguous float copy of each weight and register the results.
        for tag in ("ih", "hh"):
            source = getattr(other, f"weight_{tag}")
            q_weight, col_offsets, scale, zero_point = (
                torch.fbgemm_linear_quantize_weight(
                    source.clone(memory_format=torch.contiguous_format).float()
                )
            )
            setattr(self, f"scale_{tag}", scale)
            setattr(self, f"zero_point_{tag}", zero_point)
            self.register_buffer(f"weight_{tag}", q_weight)
            self.register_buffer(f"col_offsets_{tag}", col_offsets)

        # Packed forms are registered after both quantized weights exist.
        for tag in ("ih", "hh"):
            packed = torch.fbgemm_pack_quantized_matrix(getattr(self, f"weight_{tag}"))
            self.register_buffer(f"packed_{tag}", packed)

        for tag in ("ih", "hh"):
            float_bias = (
                getattr(other, f"bias_{tag}")
                .clone(memory_format=torch.contiguous_format)
                .float()
            )
            setattr(
                self,
                f"bias_{tag}",
                torch.nn.Parameter(float_bias, requires_grad=False),
            )

    def extra_repr(self):
        """Return sizes for ``repr``, plus bias/nonlinearity when non-default."""
        fields = self.__dict__
        pieces = ["{input_size}, {hidden_size}"]
        if "bias" in fields and self.bias is not True:
            pieces.append(", bias={bias}")
        if "nonlinearity" in fields and self.nonlinearity != "tanh":
            pieces.append(", nonlinearity={nonlinearity}")
        return "".join(pieces).format(**fields)

    @torch.jit.script_method
    def check_forward_input(self, input):
        """Raise ``RuntimeError`` if ``input.size(1)`` differs from ``input_size``."""
        if input.size(1) != self.input_size:
            raise RuntimeError(
                f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}"
            )

    @torch.jit.script_method
    def check_forward_hidden(
        self, input: Tensor, hx: Tensor, hidden_label: str = ""
    ) -> None:
        """Validate ``hx`` against ``input``'s batch size and ``hidden_size``.

        ``hidden_label`` is inserted after the word "hidden" in error messages.
        Raises ``RuntimeError`` on either mismatch.
        """
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}"
            )

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}"
            )

    # TODO: for some reason weak_script_method causes a destruction of the
    # module to occur, which in turn frees the packed_ih object via its DataPtr
    # deleter. This is bizarre and should probably get fixed.
    # @torch._jit_internal.weak_script_method
    @torch.jit.script_method
    def _unpack(self):
        """Rebuild both packed buffers in place from the quantized weights."""
        self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(self.weight_ih))
        self.packed_hh.set_(torch.fbgemm_pack_quantized_matrix(self.weight_hh))

    # @torch._jit_internal.weak_script_method
    @torch.jit.script_method
    def _pack(self):
        """Replace both packed buffers in place with zero-dim uint8 placeholders."""
        self.packed_ih.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
        )
        self.packed_hh.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
        )
| |
| |
class QuantizedRNNCell(QuantizedRNNCellBase):
    """Deprecated scripted quantized RNN cell with a tanh or relu nonlinearity.

    ``forward`` calls ``_VF.quantized_rnn_tanh_cell`` or
    ``_VF.quantized_rnn_relu_cell`` depending on ``self.nonlinearity`` (copied
    from ``other`` in the constructor) and raises ``RuntimeError`` for any
    other value.
    """

    __constants__ = [
        "input_size",
        "hidden_size",
        "bias",
        "scale_hh",
        "scale_ih",
        "zero_point_ih",
        "zero_point_hh",
        "nonlinearity",
    ]

    def __init__(self, other):
        super().__init__(other)
        warnings.warn(
            "torch.jit.QuantizedRNNCell is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
        )
        self.nonlinearity = other.nonlinearity

    @torch.jit.script_method
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        """Run one cell step; a missing ``hx`` defaults to zeros."""
        self.check_forward_input(input)
        if hx is None:
            # Default hidden state: (input.size(0), hidden_size) zeros with the
            # input's dtype and device.
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        self.check_forward_hidden(input, hx, "")
        if self.nonlinearity == "tanh":
            ret = _VF.quantized_rnn_tanh_cell(
                input,
                hx,
                self.weight_ih,
                self.weight_hh,
                self.bias_ih,
                self.bias_hh,
                self.packed_ih,
                self.packed_hh,
                self.col_offsets_ih,
                self.col_offsets_hh,
                self.scale_ih,
                self.scale_hh,
                self.zero_point_ih,
                self.zero_point_hh,
            )
        elif self.nonlinearity == "relu":
            ret = _VF.quantized_rnn_relu_cell(
                input,
                hx,
                self.weight_ih,
                self.weight_hh,
                self.bias_ih,
                self.bias_hh,
                self.packed_ih,
                self.packed_hh,
                self.col_offsets_ih,
                self.col_offsets_hh,
                self.scale_ih,
                self.scale_hh,
                self.zero_point_ih,
                self.zero_point_hh,
            )
        else:
            # ``ret`` is bound on this path too so every branch defines it.
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
        return ret
| |
| |
class QuantizedLSTMCell(QuantizedRNNCellBase):
    """Deprecated scripted quantized LSTM cell built on ``_VF.quantized_lstm_cell``."""

    def __init__(self, other):
        super().__init__(other)
        warnings.warn(
            "torch.jit.QuantizedLSTMCell is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead."
        )

    @torch.jit.script_method
    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        """Run one LSTM cell step; a missing ``hx`` defaults to a pair of zeros."""
        self.check_forward_input(input)
        if hx is None:
            # The same zero tensor is used for both entries of the state pair.
            zero_state = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
            hx = (zero_state, zero_state)
        self.check_forward_hidden(input, hx[0], "[0]")
        self.check_forward_hidden(input, hx[1], "[1]")
        return _VF.quantized_lstm_cell(
            input,
            hx,
            self.weight_ih,
            self.weight_hh,
            self.bias_ih,
            self.bias_hh,
            self.packed_ih,
            self.packed_hh,
            self.col_offsets_ih,
            self.col_offsets_hh,
            self.scale_ih,
            self.scale_hh,
            self.zero_point_ih,
            self.zero_point_hh,
        )
| |
| |
class QuantizedGRUCell(QuantizedRNNCellBase):
    """Deprecated scripted quantized GRU cell built on ``_VF.quantized_gru_cell``."""

    def __init__(self, other):
        super().__init__(other)
        warnings.warn(
            "torch.jit.QuantizedGRUCell is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRUCell instead."
        )

    @torch.jit.script_method
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        """Run one GRU cell step; a missing ``hx`` defaults to zeros."""
        self.check_forward_input(input)
        if hx is None:
            # Default hidden state: (input.size(0), hidden_size) zeros with the
            # input's dtype and device.
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        self.check_forward_hidden(input, hx, "")
        return _VF.quantized_gru_cell(
            input,
            hx,
            self.weight_ih,
            self.weight_hh,
            self.bias_ih,
            self.bias_hh,
            self.packed_ih,
            self.packed_hh,
            self.col_offsets_ih,
            self.col_offsets_hh,
            self.scale_ih,
            self.scale_hh,
            self.zero_point_ih,
            self.zero_point_hh,
        )
| |
| |
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    """Return ``tensor`` with its entries along ``dim`` selected by ``permutation``."""
    return torch.index_select(tensor, dim, permutation)
| |
| |
class QuantizedRNNBase(torch.jit.ScriptModule):
    """Deprecated shared base for the scripted quantized LSTM/GRU modules.

    The constructor copies the configuration of ``other`` and builds one
    quantized cell-parameter object per (layer, direction), collected in
    ``all_weights``. Only modes ``"LSTM"`` and ``"GRU"`` and dtypes
    ``torch.int8`` and ``torch.float16`` are accepted.
    """

    __constants__ = [
        "mode",
        "input_size",
        "hidden_size",
        "num_layers",
        "bias",
        "batch_first",
        "dropout",
        "bidirectional",
        "dtype",
    ]

    def __init__(self, other, dtype=torch.int8):
        super().__init__()
        warnings.warn(
            "torch.jit.QuantizedRNNBase is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic instead."
        )
        self.mode = other.mode
        self.input_size = other.input_size
        self.hidden_size = other.hidden_size
        self.num_layers = other.num_layers
        self.bias = other.bias
        self.batch_first = other.batch_first
        # batch_first is only permitted when the mode is GRU.
        if self.mode != "GRU":
            assert not self.batch_first
        self.dropout = other.dropout
        self.bidirectional = other.bidirectional
        num_directions = 2 if self.bidirectional else 1
        self.dtype = dtype

        assert self.bias

        # TODO: support more than just LSTM
        if self.mode != "LSTM" and self.mode != "GRU":
            raise RuntimeError("Only LSTM or GRU is supported for QuantizedRNN")

        if dtype != torch.int8 and dtype != torch.float16:
            raise RuntimeError(f"Unsupported dtype: {dtype}")

        self.all_weights = []
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                # The second direction's parameters carry a "_reverse" suffix.
                suffix = "_reverse" if direction == 1 else ""

                weight_ih = getattr(other, f"weight_ih_l{layer}{suffix}")
                bias_ih = getattr(other, f"bias_ih_l{layer}{suffix}")
                weight_hh = getattr(other, f"weight_hh_l{layer}{suffix}")
                bias_hh = getattr(other, f"bias_hh_l{layer}{suffix}")

                if dtype == torch.float16:
                    # fp16 path: prepack each (weight, bias) pair first.
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(
                        weight_ih.float(), bias_ih
                    )
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(
                        weight_hh.float(), bias_hh
                    )
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )
                else:
                    # dtype was validated above, so this is the int8 path.
                    cell_params = torch.ops.quantized.make_quantized_cell_params(
                        weight_ih, weight_hh, bias_ih, bias_hh
                    )

                setattr(self, f"cell_params_{layer}_{suffix}", cell_params)
                self.all_weights.append(cell_params)

    @torch.jit.script_method
    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        """Validate ``input``'s rank and last dimension.

        The expected rank is 2 when ``batch_sizes`` is given and 3 otherwise.
        Raises ``RuntimeError`` on a rank or ``input_size`` mismatch.
        """
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                f"input must have {expected_input_dim} dimensions, got {input.dim()}"
            )
        if self.input_size != input.size(-1):
            raise RuntimeError(
                f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
            )

    @torch.jit.script_method
    def get_expected_hidden_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
        """Return ``(num_layers * num_directions, mini_batch, hidden_size)``.

        ``mini_batch`` comes from ``batch_sizes[0]`` when ``batch_sizes`` is
        given, otherwise from the batch dimension of ``input`` as selected by
        ``batch_first``.
        """
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        return (
            self.num_layers * num_directions,
            mini_batch,
            self.hidden_size,
        )

    @torch.jit.script_method
    def check_hidden_size(
        self,
        hx: Tensor,
        expected_hidden_size: Tuple[int, int, int],
        msg: str = "Expected hidden size {}, got {}",
    ) -> None:
        """Raise ``RuntimeError`` (formatted from ``msg``) if ``hx`` has the wrong size."""
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

    @torch.jit.script_method
    def check_forward_args(
        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
    ) -> None:
        """Validate ``input`` and then the size of ``hidden``."""
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(
            hidden, expected_hidden_size, msg="Expected hidden size {}, got {}"
        )

    @torch.jit.script_method
    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        """Return ``hx`` reordered by ``permutation``; unchanged when it is ``None``."""
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)
| |
| |
class QuantizedLSTM(QuantizedRNNBase):
    """Deprecated scripted quantized LSTM built on ``torch.quantized_lstm``.

    ``forward`` dispatches on the input type to ``forward_packed``
    (``PackedSequence``) or ``forward_tensor`` (plain tensor).
    """

    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

    def __init__(self, other, dtype):
        super().__init__(other, dtype)
        warnings.warn(
            "torch.jit.QuantizedLSTM is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTM instead."
        )

    @torch.jit.script_method
    def forward_impl(
        self,
        input: Tensor,
        hx: Optional[Tuple[Tensor, Tensor]],
        batch_sizes: Optional[Tensor],
        max_batch_size: int,
        sorted_indices: Optional[Tensor],
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Shared implementation: default/permute ``hx``, validate, run the op.

        Asserts that ``batch_sizes`` is ``None`` before calling
        ``torch.quantized_lstm``. Returns ``(output, hidden)`` where ``hidden``
        is everything after the first element of the op's result.
        """
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            # The same zero tensor is used for both entries of the state pair.
            zero_state = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = (zero_state, zero_state)
        else:
            # Each batch of the hidden state should match the input sequence
            # that the caller believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        assert batch_sizes is None
        result = torch.quantized_lstm(
            input,
            hx,
            self.all_weights,
            self.bias,
            self.num_layers,
            float(self.dropout),
            self.training,
            self.bidirectional,
            self.batch_first,
            dtype=self.dtype,
            use_dynamic=False,
        )
        output = result[0]
        hidden = result[1:]

        return output, hidden

    @torch.jit.script_method
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Forward for a plain tensor input (no batch sizes, no index permutation)."""
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices
        )

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.script_method
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
        """Forward for a ``PackedSequence``; the output is re-wrapped as one."""
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])

        output, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices
        )

        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.script_method
    def permute_hidden(
        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        """Permute both entries of ``hx``; unchanged when ``permutation`` is ``None``."""
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(
            hx[1], permutation
        )

    @torch.jit.script_method
    def check_forward_args(
        self,
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        batch_sizes: Optional[Tensor],
    ) -> None:
        """Validate ``input`` and the size of each entry of ``hidden``."""
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(
            hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}"
        )
        self.check_hidden_size(
            hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}"
        )

    def forward(self, input, hx=None):
        """Dispatch to ``forward_packed`` or ``forward_tensor`` by input type."""
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)
| |
| |
class QuantizedGRU(QuantizedRNNBase):
    """Deprecated scripted quantized GRU built on ``torch.quantized_gru``.

    ``forward`` dispatches on the input type to ``forward_packed``
    (``PackedSequence``) or ``forward_tensor`` (plain tensor).
    """

    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        warnings.warn(
            "torch.jit.QuantizedGRU is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRU instead."
        )

    @torch.jit.script_method
    def forward_impl(
        self,
        input: Tensor,
        hx: Optional[Tensor],
        batch_sizes: Optional[Tensor],
        max_batch_size: int,
        sorted_indices: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        """Shared implementation: default/permute ``hx``, validate, run the op.

        Uses the ``torch.quantized_gru`` call form without ``batch_sizes`` when
        it is ``None`` and the form with ``batch_sizes`` otherwise. Returns the
        first two elements of the op's result as ``(output, hidden)``.
        """
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
        else:
            # Each batch of the hidden state should match the input sequence
            # that the caller believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = torch.quantized_gru(
                input,
                hx,
                self.all_weights,
                self.bias,
                self.num_layers,
                float(self.dropout),
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            # This call form takes ``batch_sizes`` and no ``batch_first``.
            result = torch.quantized_gru(
                input,
                batch_sizes,
                hx,
                self.all_weights,
                self.bias,
                self.num_layers,
                float(self.dropout),
                self.training,
                self.bidirectional,
            )

        output = result[0]
        hidden = result[1]

        return output, hidden

    @torch.jit.script_method
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """Forward for a plain tensor input (no batch sizes, no index permutation)."""
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices
        )
        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.script_method
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> Tuple[PackedSequence, Tensor]:
        """Forward for a ``PackedSequence``; the output is re-wrapped as one."""
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])

        output, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices
        )

        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def forward(self, input, hx=None):
        """Dispatch to ``forward_packed`` or ``forward_tensor`` by input type."""
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)
| |
| |
def quantize_rnn_cell_modules(module):
    """Recursively replace RNN cell modules under ``module`` with quantized ones.

    Deprecated: emits a warning on every call (including recursive calls).

    Children of ``module`` are replaced in place via ``setattr``. If ``module``
    itself is a ``torch.nn.LSTMCell``, ``torch.nn.GRUCell`` or
    ``torch.nn.RNNCell`` (checked in that order), a new quantized wrapper is
    returned; otherwise ``module`` is returned unchanged.
    """
    warnings.warn(
        "quantize_rnn_cell_modules function has been deprecated. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
    # Visit only direct children: the recursive call handles deeper levels.
    # Walking every descendant here as well would revisit nested modules once
    # per ancestor (exponential in nesting depth) and would produce dotted
    # names that are not valid targets for ``setattr``.
    replacements = {}
    for name, child in module.named_children():
        quantized_child = quantize_rnn_cell_modules(child)
        if quantized_child is not child:
            replacements[name] = quantized_child
    # Apply replacements after iteration so the child mapping is not modified
    # while it is being walked.
    for name, quantized_child in replacements.items():
        setattr(module, name, quantized_child)
    if isinstance(module, torch.nn.LSTMCell):
        return QuantizedLSTMCell(module)
    if isinstance(module, torch.nn.GRUCell):
        return QuantizedGRUCell(module)
    if isinstance(module, torch.nn.RNNCell):
        return QuantizedRNNCell(module)
    return module
| |
| |
def quantize_linear_modules(module, dtype=torch.int8):
    """Recursively replace ``torch.nn.Linear`` modules under ``module``.

    Deprecated: emits a warning on every call (including recursive calls).

    Children of ``module`` are replaced in place via ``setattr``. If ``module``
    itself is a ``torch.nn.Linear``, a ``QuantizedLinear`` (``torch.int8``) or
    ``QuantizedLinearFP16`` (``torch.float16``) wrapper is returned; otherwise
    ``module`` is returned unchanged.

    Raises:
        RuntimeError: if a ``torch.nn.Linear`` is encountered and ``dtype`` is
            neither ``torch.int8`` nor ``torch.float16``.
    """
    warnings.warn(
        "quantize_linear_modules function has been deprecated. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )

    # Visit only direct children: the recursive call handles deeper levels.
    # Walking every descendant here as well would revisit nested modules once
    # per ancestor (exponential in nesting depth) and would produce dotted
    # names that are not valid targets for ``setattr``.
    replacements = {}
    for name, child in module.named_children():
        quantized_child = quantize_linear_modules(child, dtype)
        if quantized_child is not child:
            replacements[name] = quantized_child

    # Apply replacements after iteration so the child mapping is not modified
    # while it is being walked.
    for name, quantized_child in replacements.items():
        setattr(module, name, quantized_child)
    if isinstance(module, torch.nn.Linear):
        if dtype == torch.int8:
            return QuantizedLinear(module)
        elif dtype == torch.float16:
            return QuantizedLinearFP16(module)
        else:
            raise RuntimeError(f"Unsupported dtype: {dtype}")
    return module
| |
| |
def quantize_rnn_modules(module, dtype=torch.int8):
    """Recursively replace ``torch.nn.LSTM`` / ``torch.nn.GRU`` under ``module``.

    Deprecated: emits a warning on every call (including recursive calls).

    Children of ``module`` are replaced in place via ``setattr``. If ``module``
    itself is a ``torch.nn.LSTM`` a ``QuantizedLSTM`` built with ``dtype`` is
    returned; if it is a ``torch.nn.GRU`` a ``QuantizedGRU`` is returned;
    otherwise ``module`` is returned unchanged.

    Raises:
        RuntimeError: if a ``torch.nn.LSTM`` is encountered and ``dtype`` is
            neither ``torch.int8`` nor ``torch.float16``.
    """
    warnings.warn(
        "quantize_rnn_modules function has been deprecated. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
    # Visit only direct children: the recursive call handles deeper levels.
    # Walking every descendant here as well would revisit nested modules once
    # per ancestor (exponential in nesting depth) and would produce dotted
    # names that are not valid targets for ``setattr``.
    replacements = {}
    for name, child in module.named_children():
        quantized_child = quantize_rnn_modules(child, dtype)
        if quantized_child is not child:
            replacements[name] = quantized_child

    # Apply replacements after iteration so the child mapping is not modified
    # while it is being walked.
    for name, quantized_child in replacements.items():
        setattr(module, name, quantized_child)
    if isinstance(module, torch.nn.LSTM):
        if dtype != torch.int8 and dtype != torch.float16:
            raise RuntimeError(f"Unsupported dtype: {dtype}")
        return QuantizedLSTM(module, dtype)
    if isinstance(module, torch.nn.GRU):
        # NOTE(review): ``dtype`` is not forwarded here, so the GRU wrapper is
        # built with its own default dtype -- confirm this is intentional.
        return QuantizedGRU(module)
    return module