| from __future__ import absolute_import, division, print_function, unicode_literals |
| |
| |
| def _find_match(str_list, key_str, postfix): |
| split_str = key_str.split(".") |
| if split_str[-1] == postfix: |
| match_string = "".join(key_str.split(".")[0:-1]) |
| for s2 in str_list: |
| pattern1 = "".join(s2.split(".")[0:-1]) |
| pattern2 = "".join(s2.split(".")[0:-2]) |
| if match_string == pattern1: |
| return s2 |
| if match_string == pattern2: |
| return s2 |
| else: |
| return None |
| |
| |
def compare_weights(float_dict, quantized_dict):
    r"""Pair up the weights of a float model with those of its quantized version.

    For every key of ``quantized_dict``, ``_find_match`` is asked for the
    corresponding key of ``float_dict`` (only keys whose last component is
    "weight" can match). Each successful match yields one entry, keyed by the
    quantized key, holding both weights. The result can be used to compare
    the two and to compute the quantization error of the weights.

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Returns:
        weight_dict: dict mapping each matched key of ``quantized_dict`` to a
        dictionary with two keys, 'float' and 'quantized', containing the
        float and quantized weights respectively
    """
    weight_dict = {}
    for quantized_key in quantized_dict:
        float_key = _find_match(float_dict, quantized_key, "weight")
        if float_key is None:
            # No counterpart in the float model (or not a weight); skip it.
            continue
        weight_dict[quantized_key] = {
            "float": float_dict[float_key],
            "quantized": quantized_dict[quantized_key],
        }
    return weight_dict