| """ |
| Utils shared by different modes of quantization (eager/graph) |
| """ |
| import warnings |
| |
| import torch |
| from .quant_type import QuantType, quant_type_to_str |
| |
| def get_combined_dict(default_dict, additional_dict): |
| d = default_dict.copy() |
| d.update(additional_dict) |
| return d |
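
# A minimal sketch of the precedence semantics (the mapping values here are
# hypothetical placeholders):
#
#   base = {torch.nn.Linear: 'DefaultObservedLinear'}
#   extra = {torch.nn.Linear: 'CustomObservedLinear'}
#   get_combined_dict(base, extra)
#   # -> {torch.nn.Linear: 'CustomObservedLinear'}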

def is_per_tensor(qscheme):
    """ Returns True if `qscheme` is a per tensor quantization scheme """
    return qscheme == torch.per_tensor_affine or \
        qscheme == torch.per_tensor_symmetric

def is_per_channel(qscheme):
    """ Returns True if `qscheme` is a per channel quantization scheme """
    return qscheme in [torch.per_channel_affine,
                       torch.per_channel_affine_float_qparams,
                       torch.per_channel_symmetric]
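
# A quick sanity check of the qscheme predicates, using the built-in torch
# qschemes:
#
#   is_per_tensor(torch.per_tensor_affine)       # True
#   is_per_tensor(torch.per_channel_affine)      # False
#   is_per_channel(torch.per_channel_symmetric)  # True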

def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    """ Get the observed/quantized custom module class that we need
    to swap `custom_module` to.

    Input:
        custom_module: an instance of either a float or observed custom module
        custom_module_class_mapping: the float to observed or observed to
            quantized custom module class mapping, keyed by quant type string
        qconfig: qconfig configured for the custom module

    Output:
        corresponding observed/quantized custom module class for the input
        custom module instance
    """
    quant_type = get_quant_type(qconfig)
    quant_type_str = quant_type_to_str(quant_type)
    class_mapping = custom_module_class_mapping.get(quant_type_str, {})
    assert type(custom_module) in class_mapping, \
        "did not find corresponding observed/quantized module class for {} " \
        "in mapping: {}".format(type(custom_module), class_mapping)
    return class_mapping[type(custom_module)]
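
# A minimal sketch of the expected mapping shape. `MyFloatModule` and
# `MyObservedModule` are hypothetical user-defined classes; the outer keys
# are quant type strings as produced by `quant_type_to_str`:
#
#   mapping = {"static": {MyFloatModule: MyObservedModule}}
#   get_swapped_custom_module_class(float_instance, mapping, qconfig)
#   # -> MyObservedModule, assuming `qconfig` resolves to static quantization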

def activation_dtype(qconfig):
    """ Returns the dtype of the activation observer in `qconfig` """
    assert qconfig is not None
    activation = qconfig.activation()
    return activation.dtype

def weight_dtype(qconfig):
    """ Returns the dtype of the weight observer in `qconfig` """
    assert qconfig is not None
    weight = qconfig.weight()
    return weight.dtype
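
# For example, with the default server (fbgemm) qconfig -- a sketch, since
# the exact default observers can vary across releases:
#
#   qconfig = torch.quantization.get_default_qconfig('fbgemm')
#   activation_dtype(qconfig)  # torch.quint8
#   weight_dtype(qconfig)      # torch.qint8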

def activation_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    statically quantized or not; this includes quantizing to
    quint8, qint8 and float16
    """
    return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]

def activation_is_int8_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    quantized to int8 or not; this includes quantizing to
    quint8 and qint8
    """
    return activation_dtype(qconfig) in [torch.quint8, torch.qint8]

def weight_is_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be
    quantized or not; this includes quantizing to quint8,
    qint8 and float16
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]

def weight_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be statically
    quantized or not; this includes quantizing to quint8 and qint8
    """
    return weight_dtype(qconfig) in [torch.quint8, torch.qint8]
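
# A sketch of how these predicates behave for the default qconfigs, assuming
# the default dynamic qconfig keeps activations in float and quantizes
# weights to qint8:
#
#   static_qconfig = torch.quantization.get_default_qconfig('fbgemm')
#   dynamic_qconfig = torch.quantization.default_dynamic_qconfig
#   activation_is_statically_quantized(static_qconfig)   # True
#   activation_is_statically_quantized(dynamic_qconfig)  # False
#   weight_is_statically_quantized(dynamic_qconfig)      # True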

def get_qconfig_dtypes(qconfig):
    r""" Returns the dtypes configured in `qconfig` as a tuple:
    (activation_dtype, weight_dtype, activation_compute_dtype)

    `activation_compute_dtype` is None if the activation observer
    does not define a `compute_dtype` attribute.
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None
    return (activation.dtype, weight.dtype, compute_dtype)
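
# For instance, with a dynamic qconfig that stores activations in float but
# computes in quint8 (a sketch; the exact defaults vary across releases):
#
#   get_qconfig_dtypes(torch.quantization.default_dynamic_qconfig)
#   # -> (torch.float, torch.qint8, torch.quint8)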

def get_quant_type(qconfig):
    """ Returns the QuantType (STATIC, DYNAMIC or WEIGHT_ONLY) implied by
    the activation and weight dtypes configured in `qconfig`
    """
    assert qconfig is not None
    activation = qconfig.activation()
    weight = qconfig.weight()
    static_dtypes = [torch.quint8, torch.qint8]
    if weight.dtype in static_dtypes:
        if activation.dtype in static_dtypes:
            return QuantType.STATIC
        elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
            return QuantType.DYNAMIC
        else:
            return QuantType.WEIGHT_ONLY

    if weight.dtype == torch.float16:
        if activation.dtype == torch.float:
            return QuantType.DYNAMIC
        elif activation.dtype == torch.float16:
            return QuantType.STATIC

    raise ValueError("Unrecognized dtype combination in get_quant_type: "
                     "activation({}), weight({})".format(activation.dtype, weight.dtype))

def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """ Checks if the given minimum and maximum values are valid, meaning that
    they exist and the min value is less than or equal to the max value.
    """
    if min_val.numel() == 0 or max_val.numel() == 0:
        warnings.warn(
            "must run observer before calling calculate_qparams. " +
            "Returning default values."
        )
        return False

    if min_val.dim() == 0 or max_val.dim() == 0:
        if min_val == float("inf") and max_val == float("-inf"):
            warnings.warn(
                "must run observer before calling calculate_qparams. " +
                "Returning default values."
            )
            return False

        assert min_val <= max_val, "min {} should be less than or equal to max {}".format(
            min_val, max_val
        )
    else:
        assert torch.all(
            min_val <= max_val
        ), "min {} should be less than or equal to max {}".format(min_val, max_val)

    return True
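
# A quick sanity check of the three cases (uninitialized scalar trackers,
# valid scalars, and per-channel tensors):
#
#   check_min_max_valid(torch.tensor(float("inf")), torch.tensor(float("-inf")))
#   # -> warns and returns False (observer was never run)
#   check_min_max_valid(torch.tensor(-1.0), torch.tensor(1.0))
#   # -> True
#   check_min_max_valid(torch.tensor([-1.0, 0.0]), torch.tensor([1.0, 2.0]))
#   # -> True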