ns for fx: add utils for l2 error and cosine similarity (#61380)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61380

Adds convenience wrappers for normalized L2 error and cosine similarity
to the Numeric Suite (NS) utils, following the existing compute_sqnr pattern.
The shared dequantize-and-handle-tuples logic is factored out into a decorator
that all three comparison functions now use.
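
For reference, a minimal usage sketch (the comparison calls mirror the test
changes below; the toy model, qconfig, and the
torch.quantization._numeric_suite_fx import path are assumptions for
illustration only):

```
import copy
import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx

# comparison fns; the L2 error and cosine similarity helpers are added in this PR
from torch.quantization.ns.utils import (
    compute_sqnr,
    compute_normalized_l2_error,
    compute_cosine_similarity,
)
# NS core APIs; import path assumed, matching the existing NS tests
from torch.quantization._numeric_suite_fx import (
    extract_weights,
    extend_logger_results_with_comparison,
)

# toy fp32 model and its int8 counterpart
m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
mp = quantize_fx.prepare_fx(m, {'': torch.quantization.default_qconfig})
mp(torch.randn(1, 1, 4, 4))  # calibration
mq = quantize_fx.convert_fx(copy.deepcopy(mp))

# extract weight pairs, then append all three comparison metrics
results = extract_weights('fp32', mp, 'int8', mq)
extend_logger_results_with_comparison(
    results, 'fp32', 'int8', compute_sqnr, 'sqnr_int8_vs_fp32')
extend_logger_results_with_comparison(
    results, 'fp32', 'int8', compute_normalized_l2_error, 'l2_error_int8_vs_fp32')
extend_logger_results_with_comparison(
    results, 'fp32', 'int8', compute_cosine_similarity,
    'cosine_similarity_int8_vs_fp32')
```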

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_extend_logger_results_with_comparison
```

Imported from OSS

Reviewed By: hx89

Differential Revision: D29600354

fbshipit-source-id: 670c44a44df7f345884cacf26ed3c885edbe9977
diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py
index 307d0c1..7e4b024 100644
--- a/test/quantization/fx/test_numeric_suite_fx.py
+++ b/test/quantization/fx/test_numeric_suite_fx.py
@@ -42,6 +42,8 @@
 )
 from torch.quantization.ns.utils import (
     compute_sqnr,
+    compute_normalized_l2_error,
+    compute_cosine_similarity,
 )
 from torch.quantization.ns.mappings import (
     get_node_type_to_io_type_map,
@@ -778,6 +780,13 @@
                     len(results) == results_len,
                     f"expected len {results_len}, got len {len(results)}")
                 self.assert_ns_compare_dict_valid(results)
+                extend_logger_results_with_comparison(
+                    results, 'a', 'b', compute_sqnr, 'sqnr')
+                extend_logger_results_with_comparison(
+                    results, 'a', 'b', compute_normalized_l2_error, 'l2_error')
+                extend_logger_results_with_comparison(
+                    results, 'a', 'b', compute_cosine_similarity,
+                    'cosine_similarity')
 
     def _test_match_activations(
         self, m, data, prepared_expected_node_occurrence=None, results_len=0,
@@ -834,6 +843,13 @@
                 len(act_compare_dict) == results_len,
                 f"expected len {results_len}, got len {len(act_compare_dict)}")
             self.assert_ns_compare_dict_valid(act_compare_dict)
+            extend_logger_results_with_comparison(
+                act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')
+            extend_logger_results_with_comparison(
+                act_compare_dict, 'a', 'b', compute_normalized_l2_error, 'l2_error')
+            extend_logger_results_with_comparison(
+                act_compare_dict, 'a', 'b', compute_cosine_similarity,
+                'cosine_similarity')
             results.append(act_compare_dict)
         return results
 
@@ -890,6 +906,13 @@
                     len(act_compare_dict) == results_len,
                     f"expected len {results_len}, got len {len(act_compare_dict)}")
             self.assert_ns_compare_dict_valid(act_compare_dict)
+            extend_logger_results_with_comparison(
+                act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')
+            extend_logger_results_with_comparison(
+                act_compare_dict, 'a', 'b', compute_normalized_l2_error, 'l2_error')
+            extend_logger_results_with_comparison(
+                act_compare_dict, 'a', 'b', compute_cosine_similarity,
+                'cosine_similarity')
             results.append(act_compare_dict)
         return results
 
@@ -1669,10 +1692,19 @@
         results = extract_weights('fp32', mp, 'int8', mq)
         extend_logger_results_with_comparison(
             results, 'fp32', 'int8', compute_sqnr, 'sqnr_int8_vs_fp32')
+        extend_logger_results_with_comparison(
+            results, 'fp32', 'int8', compute_normalized_l2_error, 'l2_error_int8_vs_fp32')
+        extend_logger_results_with_comparison(
+            results, 'fp32', 'int8', compute_cosine_similarity,
+            'cosine_similarity_int8_vs_fp32')
 
         for layer_name, layer_results in results.items():
             assert 'sqnr_int8_vs_fp32' in \
                 layer_results['weight']['int8'][0].keys()
+            assert 'l2_error_int8_vs_fp32' in \
+                layer_results['weight']['int8'][0].keys()
+            assert 'cosine_similarity_int8_vs_fp32' in \
+                layer_results['weight']['int8'][0].keys()
 
     @skipIfNoFBGEMM
     def test_int8_shadows_fp32_simple(self):
diff --git a/torch/quantization/ns/utils.py b/torch/quantization/ns/utils.py
index ffa3623..ae625db 100644
--- a/torch/quantization/ns/utils.py
+++ b/torch/quantization/ns/utils.py
@@ -389,12 +389,46 @@
                         fqn = ref_model_results[i]['fqn']
                         model_results[i]['fqn'] = fqn
 
+def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
+    def inner(*args, **kwargs):
+        a0, a1, *a_other = args
 
+        if isinstance(a0, tuple) and isinstance(a1, tuple):
+            results = []
+            for el0, el1 in zip(a0, a1):
+                new_args = (el0, el1, *a_other)
+                results.append(inner(*new_args, **kwargs))
+            return results
+
+        elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
+            if a0.is_quantized:
+                a0 = a0.dequantize()
+            if a1.is_quantized:
+                a1 = a1.dequantize()
+
+        # for the purposes of this util, only handle floats
+        if a0.dtype != torch.float or a1.dtype != torch.float:
+            return None
+
+        new_args = (a0, a1, *a_other)
+        return f(*new_args, **kwargs)
+    return inner
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
 def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-    if x.is_quantized:
-        x = x.dequantize()
-    if y.is_quantized:
-        y = y.dequantize()
     Ps = torch.norm(x)
     Pn = torch.norm(x - y)
     return 20 * torch.log10(Ps / Pn)
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
+def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    return torch.sqrt(((x - y) ** 2).sum() / (x ** 2).sum())
+
+@maybe_dequantize_first_two_tensor_args_and_handle_tuples
+def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    # For convolutions, the shape of the quantized weight has one additional
+    # dimension compared to the shape of the fp32 weight. Match the shapes
+    # to enable cosine similarity comparison.
+    x = x.reshape(1, -1)
+    y = y.reshape(1, -1)
+    return torch.nn.functional.cosine_similarity(x, y)