ns for fx: add utils for l2 error and cosine similarity (#61380)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61380
Adds convenience wrappers for l2 error and cosine similarity
to NS utils.
Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_extend_logger_results_with_comparison
```
Imported from OSS
Reviewed By: hx89
Differential Revision: D29600354
fbshipit-source-id: 670c44a44df7f345884cacf26ed3c885edbe9977
diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py
index 307d0c1..7e4b024 100644
--- a/test/quantization/fx/test_numeric_suite_fx.py
+++ b/test/quantization/fx/test_numeric_suite_fx.py
@@ -42,6 +42,8 @@
)
from torch.quantization.ns.utils import (
compute_sqnr,
+ compute_normalized_l2_error,
+ compute_cosine_similarity,
)
from torch.quantization.ns.mappings import (
get_node_type_to_io_type_map,
@@ -778,6 +780,13 @@
len(results) == results_len,
f"expected len {results_len}, got len {len(results)}")
self.assert_ns_compare_dict_valid(results)
+ extend_logger_results_with_comparison(
+ results, 'a', 'b', compute_sqnr, 'sqnr')
+ extend_logger_results_with_comparison(
+ results, 'a', 'b', compute_normalized_l2_error, 'l2_error')
+ extend_logger_results_with_comparison(
+ results, 'a', 'b', compute_cosine_similarity,
+ 'cosine_similarity')
def _test_match_activations(
self, m, data, prepared_expected_node_occurrence=None, results_len=0,
@@ -834,6 +843,13 @@
len(act_compare_dict) == results_len,
f"expected len {results_len}, got len {len(act_compare_dict)}")
self.assert_ns_compare_dict_valid(act_compare_dict)
+ extend_logger_results_with_comparison(
+ act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')
+ extend_logger_results_with_comparison(
+ act_compare_dict, 'a', 'b', compute_normalized_l2_error, 'l2_error')
+ extend_logger_results_with_comparison(
+ act_compare_dict, 'a', 'b', compute_cosine_similarity,
+ 'cosine_similarity')
results.append(act_compare_dict)
return results
@@ -890,6 +906,13 @@
len(act_compare_dict) == results_len,
f"expected len {results_len}, got len {len(act_compare_dict)}")
self.assert_ns_compare_dict_valid(act_compare_dict)
+ extend_logger_results_with_comparison(
+ act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')
+ extend_logger_results_with_comparison(
+ act_compare_dict, 'a', 'b', compute_normalized_l2_error, 'l2_error')
+ extend_logger_results_with_comparison(
+ act_compare_dict, 'a', 'b', compute_cosine_similarity,
+ 'cosine_similarity')
results.append(act_compare_dict)
return results
@@ -1669,10 +1692,19 @@
results = extract_weights('fp32', mp, 'int8', mq)
extend_logger_results_with_comparison(
results, 'fp32', 'int8', compute_sqnr, 'sqnr_int8_vs_fp32')
+ extend_logger_results_with_comparison(
+ results, 'fp32', 'int8', compute_normalized_l2_error, 'l2_error_int8_vs_fp32')
+ extend_logger_results_with_comparison(
+ results, 'fp32', 'int8', compute_cosine_similarity,
+ 'cosine_similarity_int8_vs_fp32')
for layer_name, layer_results in results.items():
assert 'sqnr_int8_vs_fp32' in \
layer_results['weight']['int8'][0].keys()
+ assert 'l2_error_int8_vs_fp32' in \
+ layer_results['weight']['int8'][0].keys()
+ assert 'cosine_similarity_int8_vs_fp32' in \
+ layer_results['weight']['int8'][0].keys()
@skipIfNoFBGEMM
def test_int8_shadows_fp32_simple(self):
diff --git a/torch/quantization/ns/utils.py b/torch/quantization/ns/utils.py
index ffa3623..ae625db 100644
--- a/torch/quantization/ns/utils.py
+++ b/torch/quantization/ns/utils.py
@@ -389,12 +389,46 @@
fqn = ref_model_results[i]['fqn']
model_results[i]['fqn'] = fqn
+def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
+ def inner(*args, **kwargs):
+ a0, a1, *a_other = args
+ if isinstance(a0, tuple) and isinstance(a1, tuple):
+ results = []
+ for el0, el1 in zip(a0, a1):
+ new_args = (el0, el1, *a_other)
+ results.append(inner(*new_args, **kwargs))
+ return results
+
+ elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
+ if a0.is_quantized:
+ a0 = a0.dequantize()
+ if a1.is_quantized:
+ a1 = a1.dequantize()
+
+ # for the purposes of this util, only handle floats
+ if a0.dtype != torch.float or a1.dtype != torch.float:
+ return None
+
+ new_args = (a0, a1, *a_other)
+ return f(*new_args, **kwargs)
+ return inner
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
- if x.is_quantized:
- x = x.dequantize()
- if y.is_quantized:
- y = y.dequantize()
Ps = torch.norm(x)
Pn = torch.norm(x - y)
return 20 * torch.log10(Ps / Pn)
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
+def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ return torch.sqrt(((x - y) ** 2).sum() / (x ** 2).sum())
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
+def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ # For convolutions, the shape of the quantized weight has one additional
+ # dimension compared to the shape of the fp32 weight. Match the shapes
+ # to enable cosine similarity comparison.
+ x = x.reshape(1, -1)
+ y = y.reshape(1, -1)
+ return torch.nn.functional.cosine_similarity(x, y)